{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels (extractKernels) where
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Maybe
import Futhark.IR.GPU
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Intragroup
import Futhark.Pass.ExtractKernels.StreamKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log
import Prelude hiding (log)
-- | The kernel extraction pass: turns a program in the SOACS
-- representation into one in the GPU representation.  All of the work
-- is done by 'transformProg'.
extractKernels :: Pass SOACS GPU
extractKernels = Pass kernelPassName kernelPassDescription transformProg
  where
    kernelPassName, kernelPassDescription :: String
    kernelPassName = "extract kernels"
    kernelPassDescription = "Perform kernel extraction"
-- | Run kernel extraction over a whole program.
--
-- The top-level constants are transformed first (with an empty kernel
-- path).  Every function is then transformed with the already
-- transformed constants in scope.
transformProg :: Prog SOACS -> PassM (Prog GPU)
transformProg (Prog consts funs) = do
  gpuConsts <- runDistribM . transformStms mempty $ stmsToList consts
  let constScope = scopeOf gpuConsts
  gpuFuns <- traverse (transformFunDef constScope) funs
  pure (Prog gpuConsts gpuFuns)
-- | State threaded through 'DistribM'.
--
-- Threshold keys created by 'cmpSizeLe' are numbered from a dedicated
-- counter instead of from the fresh-name source, so the two are tracked
-- separately.
--
-- Both fields are strict.  The counter is rebuilt as @x + 1@ every time
-- a threshold is made, and its old value is only forced when the
-- generated threshold name is inspected.  With a lazy field, a chain of
-- unevaluated additions could therefore accumulate.
data State = State
  { -- | Source of fresh names, used by the 'MonadFreshNames' instance.
    stateNameSource :: !VNameSource,
    -- | Number to embed in the name of the next threshold key.
    stateThresholdCounter :: !Int
  }
-- | The monad in which kernel extraction runs.
--
-- It wraps an 'RWS' computation with these three parts:
--
-- * environment: the GPU 'Scope' currently in effect;
-- * output: a 'Log' of messages;
-- * state: 'State'.
--
-- Every instance listed here comes from newtype deriving.
newtype DistribM a = DistribM (RWS (Scope GPU) Log State a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState State,
      HasScope GPU,
      LocalScope GPU,
      MonadLogger
    )
-- | Fresh names are read from, and written back to, the
-- 'stateNameSource' field of the 'State'.  All other state is left
-- untouched.
instance MonadFreshNames DistribM where
  getNameSource = stateNameSource <$> get
  putNameSource newSrc = do
    st <- get
    put st {stateNameSource = newSrc}
-- | Run a 'DistribM' computation in any monad that supplies fresh names
-- and accepts log messages.
--
-- The computation starts with an empty scope and a threshold counter of
-- 0.  The fresh-name source is taken from the surrounding monad, and
-- the final source is handed back to it.  Every message logged inside
-- is re-emitted with 'addLog' before the result is returned.
runDistribM ::
  (MonadLogger m, MonadFreshNames m) =>
  DistribM a ->
  m a
runDistribM (DistribM action) = do
  (result, logged) <- modifyNameSource step
  addLog logged
  pure result
  where
    step src =
      let (res, finalState, msgs) = runRWS action mempty (State src 0)
       in ((res, msgs), stateNameSource finalState)
-- | Transform a single function definition.
--
-- The body is transformed with an empty kernel path, in the given scope
-- extended with the function's parameters.  The entry-point
-- information, attributes, name, return types and parameters are
-- carried over unchanged.
transformFunDef ::
  (MonadFreshNames m, MonadLogger m) =>
  Scope GPU ->
  FunDef SOACS ->
  m (FunDef GPU)
transformFunDef scope (FunDef entry attrs name rettype params body) =
  runDistribM $
    FunDef entry attrs name rettype params
      <$> localScope bodyScope (transformBody mempty body)
  where
    bodyScope = scope <> scopeOfFParams params
-- | Shorthand for a sequence of statements in the GPU representation.
type GPUStms = Stms GPU
-- | Transform the statements of a body along the given kernel path.
-- The body's result is kept as it is.
transformBody :: KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody path body =
  (`mkBody` bodyResult body)
    <$> transformStms path (stmsToList (bodyStms body))
-- | Transform a list of SOACS statements into GPU statements.
--
-- Each statement is first offered to 'sequentialisedUnbalancedStm'.
--
-- * If that rewrites it, the replacement statements are spliced in
--   front of the remaining ones and processed in turn.
-- * Otherwise the statement goes to 'transformStm', and the rest of the
--   list is transformed with the resulting GPU statements in scope.
transformStms :: KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms _ [] = pure mempty
transformStms path (stm : rest) = do
  sequentialised <- sequentialisedUnbalancedStm stm
  case sequentialised of
    Just replacement ->
      transformStms path (stmsToList replacement ++ rest)
    Nothing -> do
      gpuStm <- transformStm path stm
      gpuRest <- inScopeOf gpuStm (transformStms path rest)
      pure (gpuStm <> gpuRest)
-- | Decide whether a lambda is \"unbalanced\".
--
-- A name counts as variant when it is bound inside the lambda: either a
-- lambda parameter, or a name reported by 'boundInBody' for the body
-- being inspected or for an enclosing body.
--
-- The lambda is unbalanced if any statement in its body is unbalanced
-- under these rules:
--
-- * a 'Stream' or 'Screma' is unbalanced iff its width is a variant
--   name; every other SOAC is treated as balanced;
-- * loops and basic operations are balanced;
-- * a 'WithAcc' is unbalanced iff its lambda body is;
-- * an 'If' is unbalanced iff its condition is a variant name and at
--   least one branch is unbalanced;
-- * a function call is unbalanced unless it calls a built-in function.
unbalancedLambda :: Lambda SOACS -> Bool
unbalancedLambda lam =
  bodyIsUnbalanced lambdaBound (lambdaBody lam)
  where
    lambdaBound :: Names
    lambdaBound = namesFromList (map paramName (lambdaParams lam))

    -- Constants are never variant.
    isVariant :: SubExp -> Names -> Bool
    isVariant (Var v) bound = v `nameIn` bound
    isVariant Constant {} _ = False

    bodyIsUnbalanced :: Names -> Body SOACS -> Bool
    bodyIsUnbalanced bound body =
      let bound' = bound <> boundInBody body
       in any (expIsUnbalanced bound' . stmExp) (bodyStms body)

    expIsUnbalanced :: Names -> Exp SOACS -> Bool
    expIsUnbalanced bound e =
      case e of
        Op (Stream w _ _ _ _) -> isVariant w bound
        Op (Screma w _ _) -> isVariant w bound
        Op {} -> False
        DoLoop {} -> False
        WithAcc _ lam' -> bodyIsUnbalanced bound (lambdaBody lam')
        If cond tbranch fbranch _ ->
          isVariant cond bound
            && (bodyIsUnbalanced bound tbranch || bodyIsUnbalanced bound fbranch)
        BasicOp {} -> False
        Apply fname _ _ _ -> not (isBuiltInFunction fname)
-- | Possibly sequentialise a redomap-shaped statement.
--
-- This applies when the statement is a 'Screma' recognised by
-- 'isRedomapSOAC', and the lambda it returns is both unbalanced
-- (see 'unbalancedLambda') and contains parallelism.  In that case the
-- SOAC is rewritten with 'FOT.transformSOAC', and the resulting SOACS
-- statements are returned.  Every other statement yields 'Nothing'.
sequentialisedUnbalancedStm :: Stm SOACS -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat _ (Op soac@(Screma _ _ form)))
  | Just (_, lam) <- isRedomapSOAC form,
    unbalancedLambda lam && lambdaContainsParallelism lam = do
    soacsScope <- asksScope scopeForSOACs
    (_, seqStms) <- runBuilderT (FOT.transformSOAC pat soac) soacsScope
    pure (Just seqStms)
sequentialisedUnbalancedStm _ = pure Nothing
-- | Create a threshold comparison.
--
-- A threshold key named @desc_N@ is made, where @N@ is the current
-- threshold counter, and the counter is then incremented.  The
-- generated statements do two things:
--
-- * bind the 64-bit product of @to_what@ (starting from 1) as
--   @comparatee@;
-- * bind, under the name @desc@, a 'CmpSizeLe' size operation on that
--   key, @size_class@ and the product.
--
-- Returns the comparison result and the key, together with the
-- statements.
cmpSizeLe ::
  String ->
  SizeClass ->
  [SubExp] ->
  DistribM ((SubExp, Name), Stms GPU)
cmpSizeLe desc size_class to_what = do
  counter <- gets stateThresholdCounter
  modify $ \st -> st {stateThresholdCounter = counter + 1}
  let size_key = nameFromString (desc <> "_" <> show counter)
  runBuilder $ do
    product_se <-
      letSubExp "comparatee"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) to_what
    cmp_res <-
      letSubExp desc . Op . SizeOp $ CmpSizeLe size_key size_class product_se
    pure (cmp_res, size_key)
-- | Build statements that bind @pat@ from a list of guarded alternatives.
--
-- The alternatives are tried in order: the first one whose condition
-- holds supplies the value, and @default_body@ is used when none does.
--
-- Each alternative becomes an 'If' whose true branch is that
-- alternative's body.  Its false branch recursively binds fresh copies
-- of the pattern's names and returns them.  When no alternatives are
-- left, the default body is bound directly, and each of its results is
-- copied into the corresponding name of @pat@ under that result's
-- certificates.
kernelAlternatives ::
  (MonadFreshNames m, HasScope GPU m) =>
  Pat Type ->
  Body GPU ->
  [(SubExp, Body GPU)] ->
  m (Stms GPU)
kernelAlternatives pat default_body alternatives =
  runBuilder_ $ case alternatives of
    [] -> do
      ses <- bodyBind default_body
      forM_ (zip (patNames pat) ses) $ \(v, SubExpRes cs se) ->
        certifying cs . letBindNames [v] . BasicOp $ SubExp se
    (cond, alt) : rest -> do
      fresh_pat <- Pat <$> traverse freshen (patElems pat)
      rest_stms <- kernelAlternatives fresh_pat default_body rest
      let rest_body = mkBody rest_stms . varsRes $ patNames fresh_pat
      letBind pat . If cond alt rest_body $
        IfDec (staticShapes (patTypes pat)) IfEquiv
  where
    -- Copy a pattern element under a fresh name with the same base name.
    freshen pe = do
      v <- newVName . baseString $ patElemName pe
      pure pe {patElemName = v}
-- | Convert a SOACS lambda into a GPU lambda by running kernel
-- extraction over its body.  The lambda's parameters are brought into
-- scope while the body is transformed; the parameter list and the
-- return types are carried over unchanged.
transformLambda :: KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda path (Lambda params body ret) = do
  -- The body may refer to the parameters, so they must be in scope
  -- during the transformation.
  body' <- localScope (scopeOfLParams params) $ transformBody path body
  pure $ Lambda params body' ret
transformStm :: KernelPath -> Stm SOACS -> DistribM GPUStms
transformStm :: KernelPath -> Stm SOACS -> DistribM GPUStms
transformStm KernelPath
_ Stm SOACS
stm
| Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) =
Builder GPU () -> DistribM GPUStms
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM GPUStms)
-> Builder GPU () -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Builder GPU ()
forall (m :: * -> *).
(Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
Stm SOACS -> m ()
FOT.transformStmRecursively Stm SOACS
stm
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac))
| Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms KernelPath
path ([Stm SOACS] -> DistribM GPUStms)
-> (Stms SOACS -> [Stm SOACS]) -> Stms SOACS -> DistribM GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
(Stms SOACS -> DistribM GPUStms)
-> DistribM (Stms SOACS) -> DistribM GPUStms
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS () -> DistribM (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
Pat (LetDec SOACS)
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (If SubExp
c Body SOACS
tb Body SOACS
fb IfDec (BranchType SOACS)
rt)) = do
Body GPU
tb' <- KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path Body SOACS
tb
Body GPU
fb' <- KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path Body SOACS
fb
GPUStms -> DistribM GPUStms
forall (f :: * -> *) a. Applicative f => a -> f a
pure (GPUStms -> DistribM GPUStms) -> GPUStms -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ Stm GPU -> GPUStms
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> GPUStms) -> Stm GPU -> GPUStms
forall a b. (a -> b) -> a -> b
$ Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
Pat (LetDec GPU)
pat StmAux (ExpDec SOACS)
StmAux (ExpDec GPU)
aux (Exp GPU -> Stm GPU) -> Exp GPU -> Stm GPU
forall a b. (a -> b) -> a -> b
$ SubExp -> Body GPU -> Body GPU -> IfDec (BranchType GPU) -> Exp GPU
forall rep.
SubExp -> Body rep -> Body rep -> IfDec (BranchType rep) -> Exp rep
If SubExp
c Body GPU
tb' Body GPU
fb' IfDec (BranchType SOACS)
IfDec (BranchType GPU)
rt
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam)) =
Stm GPU -> GPUStms
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> GPUStms) -> (Exp GPU -> Stm GPU) -> Exp GPU -> GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
Pat (LetDec GPU)
pat StmAux (ExpDec SOACS)
StmAux (ExpDec GPU)
aux
(Exp GPU -> GPUStms) -> DistribM (Exp GPU) -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ([WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc ((WithAccInput SOACS -> WithAccInput GPU)
-> [WithAccInput SOACS] -> [WithAccInput GPU]
forall a b. (a -> b) -> [a] -> [b]
map WithAccInput SOACS -> WithAccInput GPU
forall (f :: * -> *) (p :: * -> * -> *) a b c.
(Functor f, Bifunctor p) =>
(a, b, f (p (Lambda SOACS) c)) -> (a, b, f (p (Lambda GPU) c))
transformInput [WithAccInput SOACS]
inputs) (Lambda GPU -> Exp GPU)
-> DistribM (Lambda GPU) -> DistribM (Exp GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda KernelPath
path Lambda SOACS
lam)
where
transformInput :: (a, b, f (p (Lambda SOACS) c)) -> (a, b, f (p (Lambda GPU) c))
transformInput (a
shape, b
arrs, f (p (Lambda SOACS) c)
op) =
(a
shape, b
arrs, (p (Lambda SOACS) c -> p (Lambda GPU) c)
-> f (p (Lambda SOACS) c) -> f (p (Lambda GPU) c)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((Lambda SOACS -> Lambda GPU)
-> p (Lambda SOACS) c -> p (Lambda GPU) c
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Lambda SOACS -> Lambda GPU
soacsLambdaToGPU) f (p (Lambda SOACS) c)
op)
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (DoLoop [(FParam SOACS, SubExp)]
merge LoopForm SOACS
form Body SOACS
body)) =
Scope GPU -> DistribM GPUStms -> DistribM GPUStms
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (Scope SOACS -> Scope GPU
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope (LoopForm SOACS -> Scope SOACS
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm SOACS
form) Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param DeclType] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param DeclType]
params) (DistribM GPUStms -> DistribM GPUStms)
-> DistribM GPUStms -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$
Stm GPU -> GPUStms
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> GPUStms)
-> (Body GPU -> Stm GPU) -> Body GPU -> GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
Pat (LetDec GPU)
pat StmAux (ExpDec SOACS)
StmAux (ExpDec GPU)
aux (Exp GPU -> Stm GPU)
-> (Body GPU -> Exp GPU) -> Body GPU -> Stm GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(FParam SOACS, SubExp)]
[(FParam GPU, SubExp)]
merge LoopForm GPU
form' (Body GPU -> GPUStms) -> DistribM (Body GPU) -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path Body SOACS
body
where
params :: [Param DeclType]
params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge
form' :: LoopForm GPU
form' = case LoopForm SOACS
form of
WhileLoop VName
cond ->
VName -> LoopForm GPU
forall rep. VName -> LoopForm rep
WhileLoop VName
cond
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
ps ->
VName -> IntType -> SubExp -> [(LParam GPU, VName)] -> LoopForm GPU
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam GPU, VName)]
ps
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma w arrs form)))
| Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form =
KernelPath -> MapLoop -> DistribM GPUStms
onMap KernelPath
path (MapLoop -> DistribM GPUStms) -> MapLoop -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w Lambda SOACS
lam [VName]
arrs
transformStm KernelPath
path (Let Pat (LetDec SOACS)
res_pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Screma w arrs form)))
| Just [Scan SOACS]
scans <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall rep. ScremaForm rep -> Maybe [Scan rep]
isScanSOAC ScremaForm SOACS
form,
Scan Lambda SOACS
scan_lam [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans,
Just BuilderT SOACS DistribM ()
do_iswim <- Pat Type
-> SubExp
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (BuilderT SOACS DistribM ())
forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp -> Lambda SOACS -> [(SubExp, VName)] -> Maybe (m ())
iswim Pat Type
Pat (LetDec SOACS)
res_pat SubExp
w Lambda SOACS
scan_lam ([(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ()))
-> [(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ())
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
Scope SOACS
types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms KernelPath
path ([Stm SOACS] -> DistribM GPUStms)
-> (((), Stms SOACS) -> [Stm SOACS])
-> ((), Stms SOACS)
-> DistribM GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (((), Stms SOACS) -> DistribM GPUStms)
-> DistribM ((), Stms SOACS) -> DistribM GPUStms
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS DistribM ()
-> Scope SOACS -> DistribM ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Certs -> BuilderT SOACS DistribM () -> BuilderT SOACS DistribM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs BuilderT SOACS DistribM ()
do_iswim) Scope SOACS
types
| Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = Builder GPU () -> DistribM GPUStms
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM GPUStms)
-> Builder GPU () -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ do
[SegBinOp GPU]
scan_ops <- [Scan SOACS]
-> (Scan SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan SOACS]
scans ((Scan SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU])
-> (Scan SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall a b. (a -> b) -> a -> b
$ \(Scan Lambda SOACS
scan_lam [SubExp]
nes) -> do
(Lambda SOACS
scan_lam', [SubExp]
nes', Shape
shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT GPU (State VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
scan_lam [SubExp]
nes
let scan_lam'' :: Lambda GPU
scan_lam'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scan_lam'
SegBinOp GPU -> BuilderT GPU (State VNameSource) (SegBinOp GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SegBinOp GPU -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> SegBinOp GPU -> BuilderT GPU (State VNameSource) (SegBinOp GPU)
forall a b. (a -> b) -> a -> b
$ Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scan_lam'' [SubExp]
nes' Shape
shape
let map_lam_sequential :: Lambda GPU
map_lam_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
SegLevel
lvl <- MkSegLevel GPU (State VNameSource)
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped [SubExp
w] String
"segscan" (ThreadRecommendation
-> BuilderT GPU (State VNameSource) (SegOpLevel GPU))
-> ThreadRecommendation
-> BuilderT GPU (State VNameSource) (SegOpLevel GPU)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
GPUStms -> Builder GPU ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (GPUStms -> Builder GPU ())
-> (GPUStms -> GPUStms) -> GPUStms -> Builder GPU ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm GPU -> Stm GPU) -> GPUStms -> GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm GPU -> Stm GPU
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs)
(GPUStms -> Builder GPU ())
-> BuilderT GPU (State VNameSource) GPUStms -> Builder GPU ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (State VNameSource) GPUStms
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
res_pat Certs
forall a. Monoid a => a
mempty SubExp
w [SegBinOp GPU]
scan_ops Lambda GPU
map_lam_sequential [VName]
arrs [] []
transformStm KernelPath
path (Let Pat (LetDec SOACS)
res_pat StmAux (ExpDec SOACS)
aux (Op (Screma w arrs form)))
| Just [Reduce Commutativity
comm Lambda SOACS
red_fun [SubExp]
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall rep. ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC ScremaForm SOACS
form,
let comm' :: Commutativity
comm'
| Lambda SOACS -> Bool
forall rep. Lambda rep -> Bool
commutativeLambda Lambda SOACS
red_fun = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm,
Just BuilderT SOACS DistribM ()
do_irwim <- Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (BuilderT SOACS DistribM ())
forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pat Type
Pat (LetDec SOACS)
res_pat SubExp
w Commutativity
comm' Lambda SOACS
red_fun ([(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ()))
-> [(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ())
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
Scope SOACS
types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
Stms SOACS
stms <- (Stms SOACS, Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> a
fst ((Stms SOACS, Stms SOACS) -> Stms SOACS)
-> DistribM (Stms SOACS, Stms SOACS) -> DistribM (Stms SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> BuilderT SOACS DistribM (Stms SOACS)
-> Scope SOACS -> DistribM (Stms SOACS, Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Stms SOACS -> BuilderT SOACS DistribM (Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms (Stms SOACS -> BuilderT SOACS DistribM (Stms SOACS))
-> BuilderT SOACS DistribM (Stms SOACS)
-> BuilderT SOACS DistribM (Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS DistribM ()
-> BuilderT SOACS DistribM (Stms (Rep (BuilderT SOACS DistribM)))
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (StmAux ()
-> BuilderT SOACS DistribM () -> BuilderT SOACS DistribM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
StmAux (ExpDec SOACS)
aux BuilderT SOACS DistribM ()
do_irwim)) Scope SOACS
types
KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms KernelPath
path ([Stm SOACS] -> DistribM GPUStms)
-> [Stm SOACS] -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms SOACS
stms
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat aux :: StmAux (ExpDec SOACS)
aux@(StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Screma w arrs form)))
| Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form = do
let paralleliseOuter :: DistribM GPUStms
paralleliseOuter = Builder GPU () -> DistribM GPUStms
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM GPUStms)
-> Builder GPU () -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ do
[SegBinOp GPU]
red_ops <- [Reduce SOACS]
-> (Reduce SOACS
-> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Reduce SOACS]
reds ((Reduce SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU])
-> (Reduce SOACS
-> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall a b. (a -> b) -> a -> b
$ \(Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes) -> do
(Lambda SOACS
red_lam', [SubExp]
nes', Shape
shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT GPU (State VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
red_lam [SubExp]
nes
let comm' :: Commutativity
comm'
| Lambda SOACS -> Bool
forall rep. Lambda rep -> Bool
commutativeLambda Lambda SOACS
red_lam' = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm
red_lam'' :: Lambda GPU
red_lam'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam'
SegBinOp GPU -> BuilderT GPU (State VNameSource) (SegBinOp GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SegBinOp GPU -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> SegBinOp GPU -> BuilderT GPU (State VNameSource) (SegBinOp GPU)
forall a b. (a -> b) -> a -> b
$ Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm' Lambda GPU
red_lam'' [SubExp]
nes' Shape
shape
let map_lam_sequential :: Lambda GPU
map_lam_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
SegLevel
lvl <- MkSegLevel GPU (State VNameSource)
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped [SubExp
w] String
"segred" (ThreadRecommendation
-> BuilderT GPU (State VNameSource) (SegOpLevel GPU))
-> ThreadRecommendation
-> BuilderT GPU (State VNameSource) (SegOpLevel GPU)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
GPUStms -> Builder GPU ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (GPUStms -> Builder GPU ())
-> (GPUStms -> GPUStms) -> GPUStms -> Builder GPU ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm GPU -> Stm GPU) -> GPUStms -> GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm GPU -> Stm GPU
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs)
(GPUStms -> Builder GPU ())
-> BuilderT GPU (State VNameSource) GPUStms -> Builder GPU ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat Type
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> BuilderT GPU (State VNameSource) GPUStms
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat Type
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
nonSegRed SegOpLevel GPU
SegLevel
lvl Pat Type
Pat (LetDec SOACS)
pat SubExp
w [SegBinOp GPU]
red_ops Lambda GPU
map_lam_sequential [VName]
arrs
outerParallelBody :: DistribM (Body GPU)
outerParallelBody =
Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
(Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (GPUStms -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (GPUStms -> Result -> Body GPU)
-> DistribM GPUStms -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> DistribM GPUStms
paralleliseOuter DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))
paralleliseInner :: KernelPath -> DistribM GPUStms
paralleliseInner KernelPath
path' = do
(Stm SOACS
mapstm, Stm SOACS
redstm) <-
Pat (LetDec SOACS)
-> (SubExp, [Reduce SOACS], Lambda SOACS, [VName])
-> DistribM (Stm SOACS, Stm SOACS)
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep, ExpDec rep ~ (),
Op rep ~ SOAC rep) =>
Pat (LetDec rep)
-> (SubExp, [Reduce rep], Lambda rep, [VName])
-> m (Stm rep, Stm rep)
redomapToMapAndReduce Pat (LetDec SOACS)
pat (SubExp
w, [Reduce SOACS]
reds, Lambda SOACS
map_lam, [VName]
arrs)
Scope SOACS
types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms KernelPath
path' ([Stm SOACS] -> DistribM GPUStms)
-> (Stms SOACS -> [Stm SOACS]) -> Stms SOACS -> DistribM GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> DistribM GPUStms)
-> (BuilderT SOACS DistribM () -> DistribM (Stms SOACS))
-> BuilderT SOACS DistribM ()
-> DistribM GPUStms
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< (BuilderT SOACS DistribM () -> Scope SOACS -> DistribM (Stms SOACS)
forall (m :: * -> *) rep.
MonadFreshNames m =>
BuilderT rep m () -> Scope rep -> m (Stms rep)
`runBuilderT_` Scope SOACS
types) (BuilderT SOACS DistribM () -> DistribM GPUStms)
-> BuilderT SOACS DistribM () -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$
Stms SOACS -> BuilderT SOACS DistribM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms SOACS -> BuilderT SOACS DistribM ())
-> BuilderT SOACS DistribM (Stms SOACS)
-> BuilderT SOACS DistribM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms SOACS -> BuilderT SOACS DistribM (Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms ([Stm SOACS] -> Stms SOACS
forall rep. [Stm rep] -> Stms rep
stmsFromList [Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs Stm SOACS
mapstm, Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs Stm SOACS
redstm])
innerParallelBody :: KernelPath -> DistribM (Body GPU)
innerParallelBody KernelPath
path' =
Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
(Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (GPUStms -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (GPUStms -> Result -> Body GPU)
-> DistribM GPUStms -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM GPUStms
paralleliseInner KernelPath
path' DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))
if Bool -> Bool
not (Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam)
Bool -> Bool -> Bool
|| Attr
"sequential_inner" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux
then DistribM GPUStms
paralleliseOuter
else do
((SubExp
outer_suff, Name
outer_suff_key), GPUStms
suff_stms) <-
String
-> [SubExp]
-> KernelPath
-> Maybe Int64
-> DistribM ((SubExp, Name), GPUStms)
sufficientParallelism String
"suff_outer_redomap" [SubExp
w] KernelPath
path Maybe Int64
forall a. Maybe a
Nothing
Body GPU
outer_stms <- DistribM (Body GPU)
outerParallelBody
Body GPU
inner_stms <- KernelPath -> DistribM (Body GPU)
innerParallelBody ((Name
outer_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path)
(GPUStms
suff_stms GPUStms -> GPUStms -> GPUStms
forall a. Semigroup a => a -> a -> a
<>) (GPUStms -> GPUStms) -> DistribM GPUStms -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pat Type -> Body GPU -> [(SubExp, Body GPU)] -> DistribM GPUStms
forall (m :: * -> *).
(MonadFreshNames m, HasScope GPU m) =>
Pat Type -> Body GPU -> [(SubExp, Body GPU)] -> m GPUStms
kernelAlternatives Pat Type
Pat (LetDec SOACS)
pat Body GPU
inner_stms [(SubExp
outer_suff, Body GPU
outer_stms)]
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat aux :: StmAux (ExpDec SOACS)
aux@(StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Stream w arrs Parallel {} [] map_fun)))
| Bool -> Bool
not (Attr
"sequential_inner" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux) = do
Scope SOACS
types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms KernelPath
path ([Stm SOACS] -> DistribM GPUStms)
-> (((), Stms SOACS) -> [Stm SOACS])
-> ((), Stms SOACS)
-> DistribM GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd
(((), Stms SOACS) -> DistribM GPUStms)
-> DistribM ((), Stms SOACS) -> DistribM GPUStms
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS DistribM ()
-> Scope SOACS -> DistribM ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Certs -> BuilderT SOACS DistribM () -> BuilderT SOACS DistribM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (BuilderT SOACS DistribM () -> BuilderT SOACS DistribM ())
-> BuilderT SOACS DistribM () -> BuilderT SOACS DistribM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (BuilderT SOACS DistribM)))
-> SubExp
-> [SubExp]
-> Lambda (Rep (BuilderT SOACS DistribM))
-> [VName]
-> BuilderT SOACS DistribM ()
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> [SubExp] -> Lambda (Rep m) -> [VName] -> m ()
sequentialStreamWholeArray Pat (LetDec (Rep (BuilderT SOACS DistribM)))
Pat (LetDec SOACS)
pat SubExp
w [] Lambda (Rep (BuilderT SOACS DistribM))
Lambda SOACS
map_fun [VName]
arrs) Scope SOACS
types
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat aux :: StmAux (ExpDec SOACS)
aux@(StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Stream w arrs (Parallel o comm red_fun) nes fold_fun)))
| Attr
"sequential_inner" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
KernelPath -> DistribM GPUStms
paralleliseOuter KernelPath
path
| Bool
otherwise = do
((SubExp
outer_suff, Name
outer_suff_key), GPUStms
suff_stms) <-
String
-> [SubExp]
-> KernelPath
-> Maybe Int64
-> DistribM ((SubExp, Name), GPUStms)
sufficientParallelism String
"suff_outer_stream" [SubExp
w] KernelPath
path Maybe Int64
forall a. Maybe a
Nothing
Body GPU
outer_stms <- KernelPath -> DistribM (Body GPU)
outerParallelBody ((Name
outer_suff_key, Bool
True) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path)
Body GPU
inner_stms <- KernelPath -> DistribM (Body GPU)
innerParallelBody ((Name
outer_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path)
(GPUStms
suff_stms GPUStms -> GPUStms -> GPUStms
forall a. Semigroup a => a -> a -> a
<>)
(GPUStms -> GPUStms) -> DistribM GPUStms -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pat Type -> Body GPU -> [(SubExp, Body GPU)] -> DistribM GPUStms
forall (m :: * -> *).
(MonadFreshNames m, HasScope GPU m) =>
Pat Type -> Body GPU -> [(SubExp, Body GPU)] -> m GPUStms
kernelAlternatives Pat Type
Pat (LetDec SOACS)
pat Body GPU
inner_stms [(SubExp
outer_suff, Body GPU
outer_stms)]
where
paralleliseOuter :: KernelPath -> DistribM GPUStms
paralleliseOuter KernelPath
path'
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
red_fun = do
let fold_fun' :: Lambda GPU
fold_fun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
fold_fun
let ([PatElem Type]
red_pat_elems, [PatElem Type]
concat_pat_elems) =
Int -> [PatElem Type] -> ([PatElem Type], [PatElem Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([PatElem Type] -> ([PatElem Type], [PatElem Type]))
-> [PatElem Type] -> ([PatElem Type], [PatElem Type])
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
Pat (LetDec SOACS)
pat
red_pat :: Pat Type
red_pat = [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
red_pat_elems
((SubExp
num_threads, [VName]
red_results), GPUStms
stms) <-
MkSegLevel GPU DistribM
-> [String]
-> [PatElem Type]
-> SubExp
-> Commutativity
-> Lambda GPU
-> [SubExp]
-> [VName]
-> DistribM ((SubExp, [VName]), GPUStms)
forall (m :: * -> *).
(MonadFreshNames m, HasScope GPU m) =>
MkSegLevel GPU m
-> [String]
-> [PatElem Type]
-> SubExp
-> Commutativity
-> Lambda GPU
-> [SubExp]
-> [VName]
-> m ((SubExp, [VName]), GPUStms)
streamMap
MkSegLevel GPU DistribM
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped
((PatElem Type -> String) -> [PatElem Type] -> [String]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> String
baseString (VName -> String)
-> (PatElem Type -> VName) -> PatElem Type -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName) [PatElem Type]
red_pat_elems)
[PatElem Type]
concat_pat_elems
SubExp
w
Commutativity
Noncommutative
Lambda GPU
fold_fun'
[SubExp]
nes
[VName]
arrs
ScremaForm SOACS
reduce_soac <- [Reduce SOACS] -> DistribM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC [Commutativity -> Lambda SOACS -> [SubExp] -> Reduce SOACS
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
comm' Lambda SOACS
red_fun [SubExp]
nes]
(GPUStms
stms GPUStms -> GPUStms -> GPUStms
forall a. Semigroup a => a -> a -> a
<>)
(GPUStms -> GPUStms) -> DistribM GPUStms -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> GPUStms -> DistribM GPUStms -> DistribM GPUStms
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf
GPUStms
stms
( KernelPath -> Stm SOACS -> DistribM GPUStms
transformStm KernelPath
path' (Stm SOACS -> DistribM GPUStms) -> Stm SOACS -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$
Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
red_pat StmAux ()
StmAux (ExpDec SOACS)
aux {stmAuxAttrs :: Attrs
stmAuxAttrs = Attrs
forall a. Monoid a => a
mempty} (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
num_threads [VName]
red_results ScremaForm SOACS
reduce_soac)
)
| Bool
otherwise = do
let red_fun_sequential :: Lambda GPU
red_fun_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_fun
fold_fun_sequential :: Lambda GPU
fold_fun_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
fold_fun
(Stm GPU -> Stm GPU) -> GPUStms -> GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm GPU -> Stm GPU
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs)
(GPUStms -> GPUStms) -> DistribM GPUStms -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MkSegLevel GPU DistribM
-> Pat Type
-> SubExp
-> Commutativity
-> Lambda GPU
-> Lambda GPU
-> [SubExp]
-> [VName]
-> DistribM GPUStms
forall (m :: * -> *).
(MonadFreshNames m, HasScope GPU m) =>
MkSegLevel GPU m
-> Pat Type
-> SubExp
-> Commutativity
-> Lambda GPU
-> Lambda GPU
-> [SubExp]
-> [VName]
-> m GPUStms
streamRed
MkSegLevel GPU DistribM
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped
Pat Type
Pat (LetDec SOACS)
pat
SubExp
w
Commutativity
comm'
Lambda GPU
red_fun_sequential
Lambda GPU
fold_fun_sequential
[SubExp]
nes
[VName]
arrs
outerParallelBody :: KernelPath -> DistribM (Body GPU)
outerParallelBody KernelPath
path' =
Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
(Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (GPUStms -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (GPUStms -> Result -> Body GPU)
-> DistribM GPUStms -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM GPUStms
paralleliseOuter KernelPath
path' DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))
paralleliseInner :: KernelPath -> DistribM GPUStms
paralleliseInner KernelPath
path' = do
Scope SOACS
types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms KernelPath
path' ([Stm SOACS] -> DistribM GPUStms)
-> (((), Stms SOACS) -> [Stm SOACS])
-> ((), Stms SOACS)
-> DistribM GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> [Stm SOACS] -> [Stm SOACS]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs) ([Stm SOACS] -> [Stm SOACS])
-> (((), Stms SOACS) -> [Stm SOACS])
-> ((), Stms SOACS)
-> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd
(((), Stms SOACS) -> DistribM GPUStms)
-> DistribM ((), Stms SOACS) -> DistribM GPUStms
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS DistribM ()
-> Scope SOACS -> DistribM ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Pat (LetDec (Rep (BuilderT SOACS DistribM)))
-> SubExp
-> [SubExp]
-> Lambda (Rep (BuilderT SOACS DistribM))
-> [VName]
-> BuilderT SOACS DistribM ()
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> [SubExp] -> Lambda (Rep m) -> [VName] -> m ()
sequentialStreamWholeArray Pat (LetDec (Rep (BuilderT SOACS DistribM)))
Pat (LetDec SOACS)
pat SubExp
w [SubExp]
nes Lambda (Rep (BuilderT SOACS DistribM))
Lambda SOACS
fold_fun [VName]
arrs) Scope SOACS
types
innerParallelBody :: KernelPath -> DistribM (Body GPU)
innerParallelBody KernelPath
path' =
Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
(Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (GPUStms -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (GPUStms -> Result -> Body GPU)
-> DistribM GPUStms -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM GPUStms
paralleliseInner KernelPath
path' DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))
comm' :: Commutativity
comm'
| Lambda SOACS -> Bool
forall rep. Lambda rep -> Bool
commutativeLambda Lambda SOACS
red_fun, StreamOrd
o StreamOrd -> StreamOrd -> Bool
forall a. Eq a => a -> a -> Bool
/= StreamOrd
InOrder = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm
-- Screma: dissect it into simpler SOACS statements inside a fresh builder
-- (seeded with the current scope viewed as SOACS), attach the original
-- statement's certificates to every resulting statement, and then extract
-- kernels from those statements recursively.
transformStm path (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) = do
  scope <- asksScope scopeForSOACs
  (_, dissected) <- runBuilderT (dissectScrema pat w form arrs) scope
  transformStms path [certify cs stm | stm <- stmsToList dissected]
-- Sequential stream: rewrite it with 'sequentialStreamWholeArray' inside a
-- fresh builder (seeded with the current scope viewed as SOACS) and extract
-- kernels from the resulting statements recursively.
--
-- The certificates of the original statement are attached to every generated
-- statement. This matches the Screma case and the inner-parallel stream
-- path, so the certificates guarding the stream are not lost.
transformStm path (Let pat (StmAux cs _ _) (Op (Stream w arrs Sequential nes fold_fun))) = do
  types <- asksScope scopeForSOACs
  (_, seq_stms) <-
    runBuilderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types
  transformStms path $ map (certify cs) $ stmsToList seq_stms
-- Scatter: turn it into a map kernel over @w@ iterations, indexed by a fresh
-- @write_i@.
--
--   * The statements of the (GPU-converted) lambda form the kernel body.
--   * Each lambda parameter reads its input array at index @write_i@.
--   * The lambda's results are grouped per destination array (via
--     'groupScatterResults') into 'WriteReturns'. Each one carries the
--     certificates of its index and value results.
--   * Kernel return types are the pattern types with the destination-shape
--     dimensions stripped.
--
-- The kernel's preparatory statements and its binding are emitted under the
-- original statement's certificates.
transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w ivs lam as))) =
  runBuilder_ $ do
    write_i <- newVName "write_i"
    let lam' = soacsLambdaToGPU lam
        lam_body = lambdaBody lam'
        (as_ws, _, _) = unzip3 as

        toWriteReturns (a_w, a, is_vs) =
          let res_cs =
                foldMap (foldMap resCerts . fst) is_vs
                  <> foldMap (resCerts . snd) is_vs
              writes =
                [ (Slice (map (DimFix . resSubExp) is), resSubExp v)
                  | (is, v) <- is_vs
                ]
           in WriteReturns res_cs a_w a writes

        krets = map toWriteReturns (groupScatterResults as (bodyResult lam_body))
        kbody = KernelBody () (bodyStms lam_body) krets

        inputs =
          [ KernelInput (paramName p) (paramType p) p_a [Var write_i]
            | (p, p_a) <- zip (lambdaParams lam') ivs
          ]

        ret_ts = zipWith (stripArray . length) as_ws (patTypes pat)

    (kernel, kstms) <-
      mapKernel segThreadCapped [(write_i, w)] inputs ret_ts kbody
    certifying cs $ do
      addStms kstms
      letBind pat $ Op $ SegOp kernel
-- Hist: build a histogram kernel with 'histKernel'.
--
--   * The segment level comes from 'segThreadCapped', described as "seghist"
--     and requesting 'SegNoVirt'.
--   * The bucket function is converted to the GPU representation.
--   * 'histKernel' is also given a pure SOACS-to-GPU lambda converter,
--     presumably for the histogram operators.
--   * The original statement's certificates are passed on.
transformStm _ (Let orig_pat (StmAux cs _ _) (Op (Hist w imgs ops bucket_fun))) =
  runBuilder_ $ do
    lvl <- segThreadCapped [w] "seghist" (NoRecommendation SegNoVirt)
    hist_stms <-
      histKernel
        (pure . soacsLambdaToGPU)
        lvl
        orig_pat
        []
        []
        cs
        w
        ops
        (soacsLambdaToGPU bucket_fun)
        imgs
    addStms hist_stms
-- Fallback: any statement not handled by an earlier clause is passed to the
-- first-order transformation ('FOT.transformStmRecursively'), run in a
-- fresh builder. Presumably this lowers any remaining SOACs to sequential
-- code.
transformStm _ stm = runBuilder_ (FOT.transformStmRecursively stm)
-- | Build a runtime check of whether the parallelism described by @ws@ is
-- sufficient. It is measured against a threshold size identified by the
-- kernel @path@, with @def@ as an optional default value; @desc@ names the
-- comparison.
--
-- The result contains:
--
--   * the comparison outcome, which callers in this module use as a branch
--     condition in 'kernelAlternatives';
--   * the threshold's name, which callers push onto the 'KernelPath';
--   * the statements computing the check.
--
-- This is a thin wrapper around 'cmpSizeLe'.
sufficientParallelism ::
  String ->
  [SubExp] ->
  KernelPath ->
  Maybe Int64 ->
  DistribM ((SubExp, Name), Stms GPU)
sufficientParallelism desc ws path def =
  let threshold = SizeThreshold path def
   in cmpSizeLe desc threshold ws
-- | Heuristic deciding whether intra-group parallelisation of a lambda is
-- worth attempting: the "interest" score of its body must exceed 1.
--
-- A body's score is the sum of its statements' scores, where a statement
-- scores:
--
--   * 0 if it carries the @sequential@ attribute;
--   * for a map 'Screma' or a 'Scatter': 0 if it carries the
--     @sequential_inner@ attribute, otherwise the maximum of its width score
--     and its lambda body's score;
--   * for a loop: ten times its body's score;
--   * for a branch: the larger of its two branches' scores;
--   * for any other 'Screma': its width score plus its lambda body's score;
--   * for a sequential 'Stream': its lambda body's score;
--   * 0 for everything else.
--
-- The width score is 0 for a constant width below 32 and 1 otherwise.
worthIntraGroup :: Lambda SOACS -> Bool
worthIntraGroup lam = bodyInterest (lambdaBody lam) > 1
  where
    bodyInterest :: Body SOACS -> Int
    bodyInterest = sum . fmap interest . bodyStms

    interest :: Stm SOACS -> Int
    interest stm
      | "sequential" `inAttrs` attrs = 0
      | otherwise =
          case stmExp stm of
            Op (Screma w _ form@(ScremaForm _ _ screma_lam)) ->
              case isMapSOAC form of
                Just lam' -> mapLike w lam'
                Nothing -> widthScore w + bodyInterest (lambdaBody screma_lam)
            Op (Scatter w _ lam' _) ->
              mapLike w lam'
            DoLoop _ _ loop_body ->
              10 * bodyInterest loop_body
            If _ tbody fbody _ ->
              max (bodyInterest tbody) (bodyInterest fbody)
            Op (Stream _ _ Sequential _ lam') ->
              bodyInterest (lambdaBody lam')
            _ ->
              0
      where
        attrs = stmAuxAttrs (stmAux stm)
        sequential_inner = "sequential_inner" `inAttrs` attrs

        -- Small constant widths are not considered worth parallelising.
        widthScore :: SubExp -> Int
        widthScore (Constant (IntValue x))
          | intToInt64 x < 32 = 0
        widthScore _ = 1

        mapLike :: SubExp -> Lambda SOACS -> Int
        mapLike w lam'
          | sequential_inner = 0
          | otherwise = max (widthScore w) (bodyInterest (lambdaBody lam'))
-- | Heuristic deciding whether sequentialising a lambda is worthwhile: the
-- "interest" score of its body must exceed 1.
--
-- A body's score is the sum of its statements' scores, where a statement
-- scores:
--
--   * 0 if it carries the @sequential@ attribute;
--   * for a map 'Screma': 0 if it carries @sequential_inner@, otherwise its
--     lambda body's score;
--   * 0 for a 'Scatter';
--   * for a @for@ loop: ten times its body's score (while loops score 0);
--   * for a 'WithAcc': its lambda body's score;
--   * for any other 'Screma': 1 plus its lambda body's score, plus 1 more if
--     it is a redomap;
--   * 0 for everything else.
worthSequentialising :: Lambda SOACS -> Bool
worthSequentialising lam = bodyInterest (lambdaBody lam) > 1
  where
    bodyInterest :: Body SOACS -> Int
    bodyInterest body = sum (interest <$> bodyStms body)

    interest :: Stm SOACS -> Int
    interest stm
      | "sequential" `inAttrs` attrs = 0
      | otherwise =
          case stmExp stm of
            Op (Screma _ _ form@(ScremaForm _ _ lam'))
              | isJust (isMapSOAC form) ->
                  if sequential_inner then 0 else bodyInterest (lambdaBody lam')
              | otherwise ->
                  1 + bodyInterest (lambdaBody lam') + redomapBonus form
            Op Scatter {} ->
              0
            DoLoop _ ForLoop {} loop_body ->
              10 * bodyInterest loop_body
            WithAcc _ withacc_lam ->
              bodyInterest (lambdaBody withacc_lam)
            _ ->
              0
      where
        attrs = stmAuxAttrs (stmAux stm)
        sequential_inner = "sequential_inner" `inAttrs` attrs

        -- One extra point when the Screma is a redomap.
        redomapBonus :: ScremaForm SOACS -> Int
        redomapBonus form = maybe 0 (const 1) (isRedomapSOAC form)
-- | Extract kernels from a sequence of top-level SOACS statements. This runs
-- 'transformStms' in the underlying 'DistribM' monad and lifts the result
-- into the distribution-nest monad with 'liftInner'.
onTopLevelStms ::
  KernelPath ->
  Stms SOACS ->
  DistNestT GPU DistribM GPUStms
onTopLevelStms path = liftInner . transformStms path . stmsToList
-- | Distribute a @map@ loop into GPU code.
--
-- Two generators of the map body's statements are built and handed to
-- 'onMap'', which decides which versions to emit:
--
-- * @exploitOuterParallelism@ converts the whole lambda with
--   'soacsLambdaToGPU', adds all of its body statements to the accumulator at
--   once and calls 'distribute' a single time.
--
-- * @exploitInnerParallelism@ runs 'distributeMapBodyStms' over the lambda's
--   body statements, distributing them one by one.
--
-- Both run in a 'DistEnv' built for the 'KernelPath' they are given, so the
-- recursive callbacks ('onInnerMap', 'onTopLevelStms') see that path.
onMap :: KernelPath -> MapLoop -> DistribM GPUStms
onMap path (MapLoop pat aux w lam arrs) = do
  types <- askScope
  let loopnest :: LoopNesting
      loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs

      -- Distribution environment for a particular kernel path.
      env :: KernelPath -> DistEnv GPU DistribM
      env path' =
        DistEnv
          { distNest = singleNesting (Nesting mempty loopnest),
            distScope =
              scopeOfPat pat
                <> scopeForGPU (scopeOf lam)
                <> types,
            distOnInnerMap = onInnerMap path',
            distOnTopLevelStms = onTopLevelStms path',
            distSegLevel = segThreadCapped,
            distOnSOACSStms = pure . oneStm . soacsStmToGPU,
            distOnSOACSLambda = pure . soacsLambdaToGPU
          }

      -- Distribute the lambda body statement by statement.
      exploitInnerParallelism :: KernelPath -> DistribM GPUStms
      exploitInnerParallelism path' =
        runDistNestT (env path') $
          distributeMapBodyStms acc (bodyStms $ lambdaBody lam)

      -- Distribute the converted lambda body as one unit.
      exploitOuterParallelism :: KernelPath -> DistribM GPUStms
      exploitOuterParallelism path' = do
        let lam' = soacsLambdaToGPU lam
        runDistNestT (env path') $
          distribute $
            addStmsToAcc (bodyStms $ lambdaBody lam') acc

  onMap'
    (newKernel loopnest)
    path
    exploitOuterParallelism
    exploitInnerParallelism
    pat
    lam
  where
    -- The map's pattern is the single target, fed by the lambda body's
    -- result; no statements have been accumulated yet.
    acc :: DistAcc GPU
    acc =
      DistAcc
        { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
          distStms = mempty
        }
-- | True when the attributes contain
-- @AttrComp "incremental_flattening" ["only_intra"]@.
onlyExploitIntra :: Attrs -> Bool
onlyExploitIntra attrs =
  AttrComp "incremental_flattening" ["only_intra"] `inAttrs` attrs
-- | False when the attributes contain either
-- @AttrComp "incremental_flattening" ["no_outer"]@ or
-- @AttrComp "incremental_flattening" ["only_inner"]@; True otherwise.
mayExploitOuter :: Attrs -> Bool
mayExploitOuter attrs =
  not $
    AttrComp "incremental_flattening" ["no_outer"] `inAttrs` attrs
      || AttrComp "incremental_flattening" ["only_inner"] `inAttrs` attrs
-- | False when the attributes contain either
-- @AttrComp "incremental_flattening" ["no_intra"]@ or
-- @AttrComp "incremental_flattening" ["only_inner"]@; True otherwise.
mayExploitIntra :: Attrs -> Bool
mayExploitIntra attrs =
  not $
    AttrComp "incremental_flattening" ["no_intra"] `inAttrs` attrs
      || AttrComp "incremental_flattening" ["only_inner"] `inAttrs` attrs
-- | Constant passed as the optional 'Int64' argument of
-- 'sufficientParallelism' when 'onMap'' creates the @suff_intra_par@
-- threshold (the name suggests a minimum amount of inner parallelism).
intraMinInnerPar :: Int64
intraMinInnerPar = 32
-- | Generate code for a @map@ with kernel nesting @loopnest@, possibly as
-- several versions chosen at run time by threshold comparisons.
--
-- @mk_seq_stms@ and @mk_par_stms@ generate the statements of two versions of
-- the map body, given the 'KernelPath' of threshold decisions leading to them
-- ('onMap' passes its outer- and inner-parallelism strategies respectively).
-- An intra-group version is additionally attempted with
-- 'intraGroupParallelise' when 'onlyExploitIntra' holds, or when
-- 'worthIntraGroup' holds and 'mayExploitIntra' allows it.
--
-- Versions emitted, via 'kernelAlternatives':
--
-- * @sequential_inner@ attribute present: only the @mk_seq_stms@ version.
--
-- * No intra-group version: the @mk_par_stms@ version, plus (when
--   'worthSequentialising' and 'mayExploitOuter') the @mk_seq_stms@ version
--   guarded by a @suff_outer_par@ threshold.
--
-- * Intra-group version and 'onlyExploitIntra': only the intra-group version.
--
-- * Otherwise: the intra-group version guarded by a @suff_intra_par@
--   threshold combined with a check that its group size is at most
--   @GetSizeMax SizeGroup@, optionally preceded by the @suff_outer_par@
--   alternative, with the @mk_par_stms@ version as the body passed first to
--   'kernelAlternatives'.
onMap' ::
  KernelNest ->
  KernelPath ->
  (KernelPath -> DistribM (Stms GPU)) ->
  (KernelPath -> DistribM (Stms GPU)) ->
  Pat Type ->
  Lambda SOACS ->
  DistribM (Stms GPU)
onMap' loopnest path mk_seq_stms mk_par_stms pat lam = do
  types <- askScope

  -- The intra-group version is computed before any threshold is created,
  -- because which thresholds are needed depends on whether it exists.
  intra <-
    if onlyExploitIntra attrs || (worthIntraGroup lam && mayExploitIntra attrs)
      then flip runReaderT types $ intraGroupParallelise loopnest lam
      else pure Nothing

  case intra of
    _
      | "sequential_inner" `inAttrs` attrs -> do
          seq_body <- versionBody mk_seq_stms path
          kernelAlternatives pat seq_body []
    Nothing
      | Just m <- mkSeqAlts -> do
          (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m
          par_body <- versionBody mk_par_stms ((outer_suff_key, False) : path)
          (outer_suff_stms <>)
            <$> kernelAlternatives pat par_body [(outer_suff, seq_body)]
      | otherwise -> do
          par_body <- versionBody mk_par_stms path
          kernelAlternatives pat par_body []
    Just intra'@(_, _, log, intra_prelude, intra_stms)
      | onlyExploitIntra attrs -> do
          addLog log
          group_par_body <- renameBody $ mkBody intra_stms res
          (intra_prelude <>) <$> kernelAlternatives pat group_par_body []
      | otherwise -> do
          addLog log
          case mkSeqAlts of
            Nothing -> do
              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar path intra'
              par_body <-
                versionBody mk_par_stms ((intra_suff_key, False) : path)
              (intra_suff_stms <>)
                <$> kernelAlternatives pat par_body [(intra_ok, group_par_body)]
            Just m -> do
              (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m
              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar ((outer_suff_key, False) : path) intra'
              par_body <-
                versionBody
                  mk_par_stms
                  ([(outer_suff_key, False), (intra_suff_key, False)] ++ path)
              ((outer_suff_stms <> intra_suff_stms) <>)
                <$> kernelAlternatives
                  pat
                  par_body
                  [(outer_suff, seq_body), (intra_ok, group_par_body)]
  where
    nest_ws :: [SubExp]
    nest_ws = kernelNestWidths loopnest

    -- Every version returns the names bound by the map's pattern.
    res :: Result
    res = varsRes $ patNames pat

    aux :: StmAux ()
    aux = loopNestingAux $ innermostKernelNesting loopnest

    attrs :: Attrs
    attrs = stmAuxAttrs aux

    -- Build a freshly renamed body from a statement generator run at the
    -- given kernel path.
    versionBody ::
      (KernelPath -> DistribM GPUStms) ->
      KernelPath ->
      DistribM (Body GPU)
    versionBody mk path' = renameBody =<< mkBody <$> mk path' <*> pure res

    -- The outer-parallel (sequentialised body) alternative, if allowed:
    -- its threshold condition and key, the statements computing that
    -- condition, and the body itself.
    mkSeqAlts
      | worthSequentialising lam,
        mayExploitOuter attrs = Just $ do
          ((outer_suff, outer_suff_key), outer_suff_stms) <- checkSuffOuterPar
          seq_body <- versionBody mk_seq_stms ((outer_suff_key, True) : path)
          pure (outer_suff, outer_suff_key, outer_suff_stms, seq_body)
      | otherwise = Nothing

    checkSuffOuterPar =
      sufficientParallelism "suff_outer_par" nest_ws path Nothing

    -- Compute the guard for the intra-group version: sufficient intra-group
    -- parallelism AND a group size no larger than the queried maximum.
    checkSuffIntraPar
      path'
      ((_intra_min_par, intra_avail_par), group_size, _, intra_prelude, intra_stms) = do
        ((intra_ok, intra_suff_key), intra_suff_stms) <- do
          ((intra_suff, suff_key), check_suff_stms) <-
            sufficientParallelism
              "suff_intra_par"
              [intra_avail_par]
              path'
              (Just intraMinInnerPar)

          runBuilder $ do
            addStms intra_prelude
            max_group_size <-
              letSubExp "max_group_size" $ Op $ SizeOp $ GetSizeMax SizeGroup
            fits <-
              letSubExp "fits" $
                BasicOp $
                  CmpOp (CmpSle Int64) group_size max_group_size
            addStms check_suff_stms
            ok <-
              letSubExp "intra_suff_and_fits" $
                BasicOp $
                  BinOp LogAnd fits intra_suff
            pure (ok, suff_key)

        group_par_body <- renameBody $ mkBody intra_stms res
        pure (group_par_body, intra_ok, intra_suff_key, intra_suff_stms)
-- | Prune a map lambda's results down to those whose pattern element is
-- mentioned in @res@.
--
-- Pattern elements and lambda body results are paired positionally, and
-- pairs whose pattern name is not free in @res@ are dropped. Succeeds only
-- if the remaining pattern names are a permutation of the subexpressions in
-- @res@. In that case it returns the permutation computed by
-- 'isPermutationOf', together with the pruned pattern and the lambda with
-- its body result pruned to match. Otherwise it returns 'Nothing'.
removeUnusedMapResults ::
  Pat Type ->
  [SubExpRes] ->
  Lambda rep ->
  Maybe ([Int], Pat Type, Lambda rep)
removeUnusedMapResults (Pat pes) res lam = do
  let (pes', body_res) =
        unzip $ filter (used . fst) $ zip pes $ bodyResult (lambdaBody lam)
  perm <- map (Var . patElemName) pes' `isPermutationOf` map resSubExp res
  pure
    ( perm,
      Pat pes',
      lam {lambdaBody = (lambdaBody lam) {bodyResult = body_res}}
    )
  where
    used :: PatElem Type -> Bool
    used pe = patElemName pe `nameIn` freeIn res
onInnerMap ::
KernelPath ->
MapLoop ->
DistAcc GPU ->
DistNestT GPU DistribM (DistAcc GPU)
onInnerMap :: KernelPath
-> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
onInnerMap KernelPath
path maploop :: MapLoop
maploop@(MapLoop Pat Type
pat StmAux ()
aux SubExp
w Lambda SOACS
lam [VName]
arrs) DistAcc GPU
acc
| Lambda SOACS -> Bool
unbalancedLambda Lambda SOACS
lam,
Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam =
Stm SOACS -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc (MapLoop -> Stm SOACS
mapLoopStm MapLoop
maploop) DistAcc GPU
acc
| Bool
otherwise =
DistAcc GPU
-> Stm SOACS
-> DistNestT
GPU
DistribM
(Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc GPU
acc (MapLoop -> Stm SOACS
mapLoopStm MapLoop
maploop) DistNestT
GPU
DistribM
(Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU))
-> (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms GPU
post_kernels, Result
res, KernelNest
nest, DistAcc GPU
acc')
| Just ([Int]
perm, Pat Type
pat', Lambda SOACS
lam') <- Pat Type
-> Result -> Lambda SOACS -> Maybe ([Int], Pat Type, Lambda SOACS)
forall rep.
Pat Type
-> Result -> Lambda rep -> Maybe ([Int], Pat Type, Lambda rep)
removeUnusedMapResults Pat Type
pat Result
res Lambda SOACS
lam -> do
PostStms GPU -> DistNestT GPU DistribM ()
forall (m :: * -> *) rep.
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms GPU
post_kernels
[Int]
-> KernelNest
-> DistAcc GPU
-> Pat Type
-> Lambda SOACS
-> DistNestT GPU DistribM (DistAcc GPU)
multiVersion [Int]
perm KernelNest
nest DistAcc GPU
acc' Pat Type
pat' Lambda SOACS
lam'
Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU)
_ -> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap MapLoop
maploop DistAcc GPU
acc
where
discardTargets :: DistAcc rep -> DistAcc rep
discardTargets DistAcc rep
acc' =
DistAcc rep
acc' {distTargets :: Targets
distTargets = Target -> Targets
singleTarget (Pat Type
forall a. Monoid a => a
mempty, Result
forall a. Monoid a => a
mempty)}
multiVersion :: [Int]
-> KernelNest
-> DistAcc GPU
-> Pat Type
-> Lambda SOACS
-> DistNestT GPU DistribM (DistAcc GPU)
multiVersion [Int]
perm KernelNest
nest DistAcc GPU
acc' Pat Type
pat' Lambda SOACS
lam' = do
DistEnv GPU DistribM
dist_env <- DistNestT GPU DistribM (DistEnv GPU DistribM)
forall r (m :: * -> *). MonadReader r m => m r
ask
let extra_scope :: Scope GPU
extra_scope = Targets -> Scope GPU
forall rep. DistRep rep => Targets -> Scope rep
targetsScope (Targets -> Scope GPU) -> Targets -> Scope GPU
forall a b. (a -> b) -> a -> b
$ DistAcc GPU -> Targets
forall rep. DistAcc rep -> Targets
distTargets DistAcc GPU
acc'
GPUStms
stms <- DistribM GPUStms -> DistNestT GPU DistribM GPUStms
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (DistribM GPUStms -> DistNestT GPU DistribM GPUStms)
-> DistribM GPUStms -> DistNestT GPU DistribM GPUStms
forall a b. (a -> b) -> a -> b
$
Scope GPU -> DistribM GPUStms -> DistribM GPUStms
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistribM GPUStms -> DistribM GPUStms)
-> DistribM GPUStms -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ do
let maploop' :: MapLoop
maploop' = Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat Type
pat' StmAux ()
aux SubExp
w Lambda SOACS
lam' [VName]
arrs
exploitInnerParallelism :: KernelPath -> DistribM GPUStms
exploitInnerParallelism KernelPath
path' = do
let dist_env' :: DistEnv GPU DistribM
dist_env' =
DistEnv GPU DistribM
dist_env
{ distOnTopLevelStms :: Stms SOACS -> DistNestT GPU DistribM GPUStms
distOnTopLevelStms = KernelPath -> Stms SOACS -> DistNestT GPU DistribM GPUStms
onTopLevelStms KernelPath
path',
distOnInnerMap :: MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
distOnInnerMap = KernelPath
-> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
onInnerMap KernelPath
path'
}
DistEnv GPU DistribM
-> DistNestT GPU DistribM (DistAcc GPU) -> DistribM GPUStms
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU DistribM
dist_env' (DistNestT GPU DistribM (DistAcc GPU) -> DistribM GPUStms)
-> (DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistribM GPUStms
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelNest
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep a.
(Monad m, DistRep rep) =>
KernelNest -> DistNestT rep m a -> DistNestT rep m a
inNesting KernelNest
nest (DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU))
-> (DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope GPU
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistNestT GPU DistribM (DistAcc GPU) -> DistribM GPUStms)
-> DistNestT GPU DistribM (DistAcc GPU) -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$
DistAcc GPU -> DistAcc GPU
forall rep. DistAcc rep -> DistAcc rep
discardTargets
(DistAcc GPU -> DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap MapLoop
maploop' DistAcc GPU
acc {distStms :: GPUStms
distStms = GPUStms
forall a. Monoid a => a
mempty}
let lam_res' :: Result
lam_res' =
[Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) (Result -> Result) -> Result -> Result
forall a b. (a -> b) -> a -> b
$
Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult (Body SOACS -> Result) -> Body SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam'
lam'' :: Lambda SOACS
lam'' = Lambda SOACS
lam' {lambdaBody :: Body SOACS
lambdaBody = (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam') {bodyResult :: Result
bodyResult = Result
lam_res'}}
map_nesting :: LoopNesting
map_nesting = Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat Type
pat' StmAux ()
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam') [VName]
arrs
nest' :: KernelNest
nest' = Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting (Pat Type
pat', Result
lam_res') LoopNesting
map_nesting KernelNest
nest
(Stm GPU
sequentialised_kernel, GPUStms
nestw_stms) <- Scope GPU
-> DistribM (Stm GPU, GPUStms) -> DistribM (Stm GPU, GPUStms)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistribM (Stm GPU, GPUStms) -> DistribM (Stm GPU, GPUStms))
-> DistribM (Stm GPU, GPUStms) -> DistribM (Stm GPU, GPUStms)
forall a b. (a -> b) -> a -> b
$ do
let sequentialised_lam :: Lambda GPU
sequentialised_lam = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam''
MkSegLevel GPU DistribM
-> KernelNest -> Body GPU -> DistribM (Stm GPU, GPUStms)
forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, LocalScope rep m) =>
MkSegLevel rep m -> KernelNest -> Body rep -> m (Stm rep, Stms rep)
constructKernel MkSegLevel GPU DistribM
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped KernelNest
nest' (Body GPU -> DistribM (Stm GPU, GPUStms))
-> Body GPU -> DistribM (Stm GPU, GPUStms)
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
sequentialised_lam
let outer_pat :: Pat Type
outer_pat = LoopNesting -> Pat Type
loopNestingPat (LoopNesting -> Pat Type) -> LoopNesting -> Pat Type
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest
(GPUStms
nestw_stms GPUStms -> GPUStms -> GPUStms
forall a. Semigroup a => a -> a -> a
<>)
(GPUStms -> GPUStms) -> DistribM GPUStms -> DistribM GPUStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelNest
-> KernelPath
-> (KernelPath -> DistribM GPUStms)
-> (KernelPath -> DistribM GPUStms)
-> Pat Type
-> Lambda SOACS
-> DistribM GPUStms
onMap'
KernelNest
nest'
KernelPath
path
(DistribM GPUStms -> KernelPath -> DistribM GPUStms
forall a b. a -> b -> a
const (DistribM GPUStms -> KernelPath -> DistribM GPUStms)
-> DistribM GPUStms -> KernelPath -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ GPUStms -> DistribM GPUStms
forall (f :: * -> *) a. Applicative f => a -> f a
pure (GPUStms -> DistribM GPUStms) -> GPUStms -> DistribM GPUStms
forall a b. (a -> b) -> a -> b
$ Stm GPU -> GPUStms
forall rep. Stm rep -> Stms rep
oneStm Stm GPU
sequentialised_kernel)
KernelPath -> DistribM GPUStms
exploitInnerParallelism
Pat Type
outer_pat
Lambda SOACS
lam''
GPUStms -> DistNestT GPU DistribM ()
forall (m :: * -> *) rep. Monad m => Stms rep -> DistNestT rep m ()
postStm GPUStms
stms
DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc GPU
acc'