{-# LANGUAGE TypeFamilies #-}

-- | Extraction of parallelism from a SOACs program.  This generates
-- parallel constructs aimed at CPU execution, which in particular may
-- involve ad-hoc irregular nested parallelism.
module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bitraversable
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.MC
import Futhark.IR.MC qualified as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pat,
    Stm,
  )
import Futhark.IR.SOACS qualified as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToGPU (injectSOACS)
import Futhark.Tools
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util (takeLast)
import Futhark.Util.Log

-- | The monad in which extraction runs: read access to the current 'MC'
-- scope (via 'HasScope' / 'LocalScope'), layered over a 'VNameSource'
-- state used for generating fresh names.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope MC,
      LocalScope MC,
      MonadFreshNames
    )

-- XXX: throwing away the log here...
--
-- 'addLog' ignores its argument, so anything logged during extraction
-- is discarded.
instance MonadLogger ExtractM where
  addLog _ = pure ()

-- | @indexArray i param arr@ binds the lambda parameter @param@ to
-- element @i@ of the array @arr@.
--
-- Accumulator-typed parameters are bound directly to @arr@ without
-- indexing. All other parameters get @arr[i, ...]@: the outermost
-- dimension is fixed to @i@, and each dimension of the parameter's own
-- type is taken as a full slice.
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param _ p t) arr =
  Let (Pat [PatElem p t]) (defAux ()) . BasicOp $
    case t of
      Acc {} -> SubExp $ Var arr
      _ -> Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims t)

-- | Turn a map lambda into the body for a single iteration @i@.
--
-- Each lambda parameter is paired positionally with one input array
-- (note that 'zipWith' silently truncates to the shorter list). Each
-- pair produces an indexing statement built by 'indexArray'. The
-- lambda body is translated with @onBody@ while those bindings are in
-- scope. The indexing statements are then prepended to the translated
-- statements; the result is unchanged.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs = do
  let indexings = zipWith (indexArray i) (lambdaParams lam) arrs
  Body () stms res <- inScopeOf indexings $ onBody $ lambdaBody lam
  pure $ Body () (stmsFromList indexings <> stms) res

-- | Like 'mapLambdaToBody', but produces a 'KernelBody'.
--
-- Every body result becomes a 'Returns' kernel result marked
-- 'ResultMaySimplify', and keeps the certificates attached to it.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs = do
  Body () stms res <- mapLambdaToBody onBody i lam arrs
  let ret :: SubExpRes -> KernelResult
      ret (SubExpRes cs se) = Returns ResultMaySimplify cs se
  pure $ KernelBody () stms $ map ret res

-- | Convert a SOACS 'Reduce' into an MC 'SegBinOp'.
--
-- 'determineReduceOp' is run in a builder. It yields a (possibly
-- rewritten) operator, neutral elements and a 'Shape', plus any
-- statements it had to emit; those statements are returned so the
-- caller can place them before the operation. The operator is
-- marked 'Commutative' whenever 'commutativeLambda' recognises it as
-- such; otherwise the original commutativity is kept.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((lam', nes', shape), stms) <- runBuilder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  let comm'
        | commutativeLambda lam' = Commutative
        | otherwise = comm
  pure (stms, SegBinOp comm' lam'' nes' shape)

-- | Convert a SOACS 'Scan' into an MC 'SegBinOp'.
--
-- The preparation is the same as in 'reduceToSegBinOp': the statements
-- emitted by 'determineReduceOp' are returned alongside the operator.
-- The result is always marked 'Noncommutative'.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((lam', nes', shape), stms) <- runBuilder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  pure (stms, SegBinOp Noncommutative lam'' nes' shape)

-- | Convert a SOACS histogram operator into an MC 'MC.HistOp'. Despite
-- its name, this function produces an 'MC.HistOp', not a 'SegBinOp'.
--
-- The bin count, race factor and destination arrays are carried over.
-- The operator, neutral elements and operator shape come from
-- 'determineReduceOp', whose emitted statements are returned first.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((op', nes', shape), stms) <- runBuilder $ determineReduceOp op nes
  op'' <- transformLambda op'
  pure (stms, MC.HistOp num_bins rf dests nes' shape op'')

-- | Create a one-dimensional 'SegSpace' of size @w@.
--
-- Two fresh names are generated: a flat thread index ("flat_tid"), used
-- only inside the space, and the per-iteration index ("gtid"). The
-- "gtid" name is returned together with the space.
mkSegSpace :: MonadFreshNames m => SubExp -> m (VName, SegSpace)
mkSegSpace w = do
  flat <- newVName "flat_tid"
  gtid <- newVName "gtid"
  let space = SegSpace flat [(gtid, w)]
  pure (gtid, space)

-- | Re-tag a loop form from the SOACS representation to MC. Every
-- component (condition variable, or index, type, bound and loop
-- parameters) is carried over unchanged.
transformLoopForm :: LoopForm SOACS -> LoopForm MC
transformLoopForm (WhileLoop cond) = WhileLoop cond
transformLoopForm (ForLoop i it bound params) = ForLoop i it bound params

-- | Translate a single SOACS statement into MC statements.
--
-- * Basic operations and function applications are copied as-is.
-- * Loops, matches and 'WithAcc' keep their shape; their nested bodies
--   and lambdas are translated recursively.
-- * SOACs are handed to 'transformSOAC' together with the statement's
--   attributes. The statement's certificates are then attached to
--   every statement that call produces.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
transformStm (Let pat aux (BasicOp op)) =
  pure $ oneStm $ Let pat aux $ BasicOp op
transformStm (Let pat aux (Apply f args ret info)) =
  pure $ oneStm $ Let pat aux $ Apply f args ret info
transformStm (Let pat aux (DoLoop merge form body)) = do
  let form' = transformLoopForm form
  -- The loop body sees both the merge parameters and whatever the loop
  -- form binds (e.g. a for-loop index).
  body' <-
    localScope (scopeOfFParams (map fst merge) <> scopeOf form') $
      transformBody body
  pure $ oneStm $ Let pat aux $ DoLoop merge form' body'
transformStm (Let pat aux (Match ses cases defbody ret)) =
  oneStm . Let pat aux
    <$> ( Match ses
            <$> mapM transformCase cases
            <*> transformBody defbody
            <*> pure ret
        )
  where
    transformCase :: Case (Body SOACS) -> ExtractM (Case (Body MC))
    transformCase (Case vs body) = Case vs <$> transformBody body
transformStm (Let pat aux (WithAcc inputs lam)) =
  oneStm . Let pat aux
    <$> (WithAcc <$> mapM transformInput inputs <*> transformLambda lam)
  where
    -- Only the optional combining operator (a lambda) needs translating;
    -- the shape, the arrays and the operator's other component are kept.
    transformInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (bitraverse transformLambda pure) op
transformStm (Let pat aux (Op op)) =
  fmap (certify (stmAuxCerts aux)) <$> transformSOAC pat (stmAuxAttrs aux) op

-- | Translate a lambda to MC. The body is transformed with the lambda's
-- parameters in scope; the parameters and return types are reused
-- unchanged.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret

-- | Translate a sequence of statements from first to last.
--
-- The MC statements produced for each statement are brought into scope
-- before the remaining statements are translated. This way, later
-- statements can look up the names bound by earlier ones.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms stms =
  case stmsHead stms of
    Nothing -> pure mempty
    Just (stm, stms') -> do
      stm_stms <- transformStm stm
      inScopeOf stm_stms $ (stm_stms <>) <$> transformStms stms'

-- | Translate a body to MC by transforming its statements with
-- 'transformStms'. The result is kept as-is.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Convert a SOACS body to MC without extracting any parallelism.
--
-- The body is rephrased via 'injectSOACS' so that SOACs are kept and
-- wrapped in 'OtherOp'. The rephrasing itself is pure (it runs in
-- 'Identity'); the result is simply lifted into 'ExtractM'.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody = pure . runIdentity . rephraseBody toMC
  where
    toMC :: Rephraser Identity SOACS MC
    toMC = injectSOACS OtherOp

-- | Translate a function definition to MC.
--
-- The body is transformed with the function's parameters in scope.
-- The entry-point information, attributes, name, return types and
-- parameters are all carried over unchanged.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) = do
  body' <- localScope (scopeOfFParams params) $ transformBody body
  pure $ FunDef entry attrs name rettype params body'

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).

-- | Whether a freshly built construct must have its bound names renamed
-- (see 'renameIfNeeded').
data NeedsRename
  = -- | Rename all bound names.
    DoRename
  | -- | Leave the construct as it is.
    DoNotRename

-- | Rename the bound names of a value with 'renameSomething' when asked
-- to ('DoRename'). With 'DoNotRename', the value is returned unchanged.
renameIfNeeded :: Rename a => NeedsRename -> a -> ExtractM a
renameIfNeeded DoRename = renameSomething
renameIfNeeded DoNotRename = pure

-- | Build a 'SegMap' that runs the map lambda over @w@ iterations.
--
-- @onBody@ controls how the lambda body is translated. In
-- 'transformSOAC' it is 'sequentialiseBody' (keep SOACs) or
-- 'transformBody' (extract nested parallelism). The finished operation
-- is renamed according to @rename@. Presumably this ensures that two
-- versions built from the same lambda do not bind the same names.
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap rename onBody w map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  renameIfNeeded rename $
    SegMap () space (lambdaReturnType map_lam) kbody

-- | Build a 'SegRed' for a redomap of size @w@.
--
-- The map part is translated with @onBody@ (see 'transformMap'), and
-- each reduction is converted with 'reduceToSegBinOp'. Alongside the
-- (optionally renamed) 'SegRed', this returns the preparatory
-- statements from each reduction, one 'Stms' per reduction and in the
-- same order. The caller places them before the operation.
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap rename onBody w reds map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (reds_stms, reds') <- unzip <$> mapM reduceToSegBinOp reds
  op' <-
    renameIfNeeded rename $
      SegRed () space reds' (lambdaReturnType map_lam) kbody
  pure (reds_stms, op')

-- | Build a 'SegHist' for a histogram of size @w@.
--
-- This mirrors 'transformRedomap': the map part is translated with
-- @onBody@, and each histogram operator is converted with
-- 'histToSegBinOp'. The per-operator preparatory statements are
-- returned alongside the (optionally renamed) 'SegHist'.
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist rename onBody w hists map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (hists_stms, hists') <- unzip <$> mapM histToSegBinOp hists
  op' <-
    renameIfNeeded rename $
      SegHist () space hists' (lambdaReturnType map_lam) kbody
  pure (hists_stms, op')

transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC Pat Type
_ Attrs
_ JVP {} =
  forall a. HasCallStack => String -> a
error String
"transformSOAC: unhandled JVP"
transformSOAC Pat Type
_ Attrs
_ VJP {} =
  forall a. HasCallStack => String -> a
error String
"transformSOAC: unhandled VJP"
transformSOAC Pat Type
pat Attrs
_ (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
  | Just Lambda SOACS
lam <- forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form = do
      SegOp () MC
seq_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w Lambda SOACS
lam [VName]
arrs
      if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam
        then do
          SegOp () MC
par_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w Lambda SOACS
lam [VName]
arrs
          forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm (forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat (forall dec. dec -> StmAux dec
defAux ()) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) op.
Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
        else forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm (forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat (forall dec. dec -> StmAux dec
defAux ()) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) op.
Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form = do
      ([Stms MC]
seq_reds_stms, SegOp () MC
seq_op) <-
        NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
      if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam
        then do
          ([Stms MC]
par_reds_stms, SegOp () MC
par_op) <-
            NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
          forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
            forall a. Monoid a => [a] -> a
mconcat ([Stms MC]
seq_reds_stms forall a. Semigroup a => a -> a -> a
<> [Stms MC]
par_reds_stms)
              forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Stm rep -> Stms rep
oneStm (forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat (forall dec. dec -> StmAux dec
defAux ()) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) op.
Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
        else
          forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
            forall a. Monoid a => [a] -> a
mconcat [Stms MC]
seq_reds_stms
              forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Stm rep -> Stms rep
oneStm (forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat (forall dec. dec -> StmAux dec
defAux ()) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) op.
Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = do
      (VName
gtid, SegSpace
space) <- forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w
      KernelBody MC
kbody <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (KernelBody MC)
mapLambdaToKernelBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
map_lam [VName]
arrs
      ([Stms MC]
scans_stms, [SegBinOp MC]
scans') <- forall a b. [(a, b)] -> ([a], [b])
unzip forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp [Scan SOACS]
scans
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
        forall a. Monoid a => [a] -> a
mconcat [Stms MC]
scans_stms
          forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Stm rep -> Stms rep
oneStm
            ( forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat (forall dec. dec -> StmAux dec
defAux ()) forall a b. (a -> b) -> a -> b
$
                forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
                  forall {k} (rep :: k) op.
Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$
                    forall {k} lvl (rep :: k).
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegScan () SegSpace
space [SegBinOp MC]
scans' (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
map_lam) KernelBody MC
kbody
            )
  | Bool
otherwise = do
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      Scope SOACS
scope <- forall {k1} {k2} (fromrep :: k1) (torep :: k2).
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
m (Scope rep)
askScope
      Stms SOACS -> ExtractM (Stms MC)
transformStms forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k} (m :: * -> *) (rep :: k).
MonadFreshNames m =>
BuilderT rep m () -> Scope rep -> m (Stms rep)
runBuilderT_ (forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat Type
pat SubExp
w ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
-- Scatter is turned into a fully parallel 'SegMap' (there is no
-- sequential alternative, hence 'ParOp' 'Nothing').  The lambda's
-- results are grouped per destination array with
-- 'groupScatterResults', and each group becomes one 'WriteReturns'
-- kernel result.
transformSOAC pat _ (Scatter w ivs lam dests) = do
  (gtid, space) <- mkSegSpace w

  Body () kstms res <- mapLambdaToBody transformBody gtid lam ivs

  let (_, ns, _) = unzip3 dests
      -- The value results are the trailing @sum ns@ results of the
      -- lambda, and destination @i@ receives @ns !! i@ of them.  The
      -- element type of a destination is the type of the first of its
      -- values.  Taking just the last @length dests@ return types is
      -- only correct when every destination receives exactly one value.
      -- NOTE(review): this assumes values are laid out contiguously per
      -- destination, as 'groupScatterResults' groups them below.
      rets = firstOfEach ns $ takeLast (sum ns) $ lambdaReturnType lam
      -- Split the list into consecutive chunks of the given sizes and
      -- keep the first element of each non-empty chunk.
      firstOfEach (n : ns') ts =
        let (grp, ts') = splitAt n ts
         in take 1 grp ++ firstOfEach ns' ts'
      firstOfEach [] _ = []
      kres = do
        (a_w, a, is_vs) <- groupScatterResults dests res
        let cs =
              foldMap (foldMap resCerts . fst) is_vs
                <> foldMap (resCerts . snd) is_vs
            is_vs' = [(Slice $ map (DimFix . resSubExp) is, resSubExp v) | (is, v) <- is_vs]
        pure $ WriteReturns cs a_w a is_vs'
      kbody = KernelBody () kstms kres
  pure $
    oneStm $
      Let pat (defAux ()) $
        Op $
          ParOp Nothing $
            SegMap () space rets kbody
-- Histogram: a sequential version of the operation is always built.
-- If the map function contains nested parallelism, a second, parallel
-- version is also built (with renaming) and offered alongside the
-- sequential one in the resulting 'ParOp'.
transformSOAC pat _ (Hist w arrs hists map_lam) = do
  (seq_hist_stms, seq_op) <-
    transformHist DoNotRename sequentialiseBody w hists map_lam arrs

  (extra_stms, maybe_par_op) <-
    if lambdaContainsParallelism map_lam
      then do
        (par_hist_stms, par_op) <-
          transformHist DoRename transformBody w hists map_lam arrs
        pure (par_hist_stms, Just par_op)
      else pure ([], Nothing)

  let hist_stm = Let pat (defAux ()) $ Op $ ParOp maybe_par_op seq_op
  pure $ mconcat (seq_hist_stms <> extra_stms) <> oneStm hist_stm
-- A stream is not parallelised directly.  It is first rewritten, still
-- in SOACS, by 'sequentialStreamWholeArray' (using the current scope
-- cast to SOACS), and the resulting statements are then transformed
-- like any others.
transformSOAC pat _ (Stream w arrs nes lam) = do
  soacs_scope <- castScope <$> askScope
  transformStms
    =<< runBuilderT_ (sequentialStreamWholeArray pat w nes lam arrs) soacs_scope

-- | Run the extraction over a whole program.  The top-level constants
-- are transformed first.  All function definitions are then transformed
-- with the transformed constants in scope.  The computation starts from
-- an empty 'MC' scope and threads the pass's name source.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg prog =
  modifyNameSource $ runState $ runReaderT extract mempty
  where
    ExtractM extract = do
      consts' <- transformStms (progConsts prog)
      funs' <- inScopeOf consts' (traverse transformFunDef (progFuns prog))
      pure prog {progConsts = consts', progFuns = funs'}

-- | Transform a program using SOACs to a program in the 'MC'
-- representation, using some amount of flattening.
extractMulticore :: Pass SOACS MC
extractMulticore =
  Pass
    { passFunction = transformProg,
      passName = name,
      passDescription = description
    }
  where
    name = "extract multicore parallelism"
    description = "Extract multicore parallelism"