{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | Extraction of parallelism from a SOACs program.  This generates
-- parallel constructs aimed at CPU execution, which in particular may
-- involve ad-hoc irregular nested parallelism.
module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bitraversable
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.MC
import qualified Futhark.IR.MC as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pat,
    Stm,
  )
import qualified Futhark.IR.SOACS as SOACS
import qualified Futhark.IR.SOACS.Simplify as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToGPU (injectSOACS)
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util (takeLast)
import Futhark.Util.Log

-- | The monad in which extraction runs: a read-only 'Scope' of the
-- multicore representation stacked on top of a 'State' that threads
-- the 'VNameSource' used to generate fresh names.  All instances are
-- obtained by generalised newtype deriving from that stack.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope MC,
      LocalScope MC,
      MonadFreshNames
    )

-- XXX: throwing away the log here...  Every entry handed to 'addLog'
-- is discarded, so nothing logged during extraction is ever reported.
instance MonadLogger ExtractM where
  addLog = const (pure ())

-- | @indexArray i param arr@ builds a statement that binds the name of
-- @param@ to element @i@ of @arr@.  The outermost dimension is fixed
-- at @i@, and every dimension of the parameter's own type is taken in
-- full.  A parameter of accumulator type cannot be indexed, so it is
-- bound to @arr@ itself.
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param _ p t) arr =
  Let (Pat [PatElem p t]) (defAux ()) (BasicOp rhs)
  where
    rhs = case t of
      Acc {} -> SubExp (Var arr)
      _ -> Index arr (Slice (DimFix (Var i) : map sliceDim (arrayDims t)))

-- | Turn the body of a map-style lambda into a body for iteration @i@.
--
-- Each lambda parameter is first bound, via 'indexArray', to element
-- @i@ of the corresponding array in @arrs@.  The lambda body is then
-- translated with @onBody@ in a scope that contains those bindings.
-- The result is the indexing statements followed by the translated
-- statements, returning the translated result.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs = do
  Body () body_stms body_res <-
    inScopeOf indexings . onBody $ lambdaBody lam
  pure . Body () (stmsFromList indexings <> body_stms) $ body_res
  where
    -- NOTE(review): zipWith stops at the shorter list; parameters and
    -- arrays are presumably expected to match in length.
    indexings :: [Stm MC]
    indexings = zipWith (indexArray i) (lambdaParams lam) arrs

-- | Like 'mapLambdaToBody', but package the result as a 'KernelBody'.
-- Every body result becomes a 'Returns' with 'ResultMaySimplify' and
-- keeps its certificates.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs =
  toKernelBody <$> mapLambdaToBody onBody i lam arrs
  where
    toKernelBody :: Body MC -> KernelBody MC
    toKernelBody (Body () stms res) = KernelBody () stms (map toResult res)

    toResult :: SubExpRes -> KernelResult
    toResult (SubExpRes cs se) = Returns ResultMaySimplify cs se

-- | Convert a SOACS reduction into a multicore 'SegBinOp'.
--
-- The operator and neutral elements go through 'determineReduceOp'.
-- It yields a (possibly rewritten) operator, the neutral elements and
-- an operator 'Shape'.  Any statements it emits are returned alongside
-- the 'SegBinOp'.  The operator is marked 'Commutative' whenever
-- 'commutativeLambda' recognises it as such; otherwise the original
-- commutativity annotation is kept.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((red_lam, red_nes, red_shape), op_stms) <-
    runBuilder $ determineReduceOp lam nes
  mc_lam <- transformLambda red_lam
  let comm' = if commutativeLambda red_lam then Commutative else comm
  pure (op_stms, SegBinOp comm' mc_lam red_nes red_shape)

-- | Convert a SOACS scan into a multicore 'SegBinOp'.
--
-- The operator and neutral elements go through 'determineReduceOp'.
-- Any statements it emits are returned alongside the 'SegBinOp'.
-- Scan operators are always marked 'Noncommutative'.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((scan_lam, scan_nes, scan_shape), op_stms) <-
    runBuilder $ determineReduceOp lam nes
  mc_lam <- transformLambda scan_lam
  pure (op_stms, SegBinOp Noncommutative mc_lam scan_nes scan_shape)

-- | Convert a SOACS histogram operation into a multicore 'MC.HistOp'.
--
-- The bin count, race factor and destination arrays are carried over
-- unchanged.  The operator and neutral elements go through
-- 'determineReduceOp', whose resulting operator shape is stored in the
-- multicore operation.  Any statements it emits are returned alongside.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((hist_op, hist_nes, op_shape), op_stms) <-
    runBuilder $ determineReduceOp op nes
  mc_op <- transformLambda hist_op
  pure (op_stms, MC.HistOp num_bins rf dests hist_nes op_shape mc_op)

-- | Create a one-dimensional 'SegSpace' of width @w@.  Fresh names are
-- generated for the flat index (\"flat_tid\") and the per-dimension
-- index (\"gtid\"), in that order.  The latter is returned together
-- with the space.
mkSegSpace :: MonadFreshNames m => SubExp -> m (VName, SegSpace)
mkSegSpace w = do
  flat <- newVName "flat_tid"
  gtid <- newVName "gtid"
  pure (gtid, SegSpace flat [(gtid, w)])

-- | Re-express a loop form in the multicore representation.  Both
-- forms carry only names, integer types, sub-expressions and lambda
-- parameters, whose representations coincide, so every field is
-- copied unchanged.
transformLoopForm :: LoopForm SOACS -> LoopForm MC
transformLoopForm form =
  case form of
    WhileLoop cond -> WhileLoop cond
    ForLoop i it bound params -> ForLoop i it bound params

-- | Translate a single SOACS statement into multicore statements.
--
-- * Basic operations and function applications are copied as-is.
--
-- * Loops have their body translated in a scope extended with the
--   merge parameters and the loop form.
--
-- * Branches have both arms translated (true branch first).
--
-- * 'WithAcc' has its inputs' lambdas and its own lambda translated.
--
-- * SOACs are delegated to 'transformSOAC', and the certificates of
--   the original statement are attached to every resulting statement.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
transformStm (Let pat aux e) =
  case e of
    BasicOp op ->
      single $ BasicOp op
    Apply f args ret info ->
      single $ Apply f args ret info
    DoLoop merge form body -> do
      let form' = transformLoopForm form
          loop_scope = scopeOfFParams (map fst merge) <> scopeOf form'
      body' <- localScope loop_scope $ transformBody body
      single $ DoLoop merge form' body'
    If cond tbranch fbranch ret -> do
      tbranch' <- transformBody tbranch
      fbranch' <- transformBody fbranch
      single $ If cond tbranch' fbranch' ret
    WithAcc inputs lam -> do
      inputs' <- mapM transformInput inputs
      lam' <- transformLambda lam
      single $ WithAcc inputs' lam'
    Op op ->
      fmap (certify (stmAuxCerts aux))
        <$> transformSOAC pat (stmAuxAttrs aux) op
  where
    -- Wrap a translated expression in the original pattern and aux.
    single :: Exp MC -> ExtractM (Stms MC)
    single = pure . oneStm . Let pat aux

    -- Keep the shape and arrays of a 'WithAcc' input; translate any
    -- lambda carried in its third component.
    transformInput (shape, arrs, maybe_op) =
      (shape,arrs,) <$> traverse (bitraverse transformLambda pure) maybe_op

-- | Translate a lambda into the multicore representation.  The body is
-- translated with the lambda's parameters in scope; parameters and
-- return types are copied unchanged.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params body ret) = do
  body' <- localScope (scopeOfLParams params) $ transformBody body
  pure $ Lambda params body' ret

-- | Translate a sequence of statements in order.  Each statement is
-- translated, and the remaining statements are then translated in a
-- scope that includes everything the translated statement binds.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms = maybe (pure mempty) step . stmsHead
  where
    step (stm, rest) = do
      stm_stms <- transformStm stm
      inScopeOf stm_stms $ (stm_stms <>) <$> transformStms rest

-- | Translate a body by translating its statements; the result list is
-- kept unchanged.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) = do
  stms' <- transformStms stms
  pure $ Body () stms' res

-- | Translate a body without extracting any parallelism.  The body is
-- purely rephrased via 'injectSOACS' with 'OtherOp'; presumably this
-- leaves every SOAC wrapped in 'OtherOp' rather than turning it into a
-- multicore segmented operation.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody body = pure (runIdentity (rephraseBody toMC body))
  where
    toMC :: Rephraser Identity SOACS MC
    toMC = injectSOACS OtherOp

-- | Translate a function definition.  The body is translated with the
-- function's parameters in scope.  Entry point, attributes, name,
-- return type and parameters are carried over unchanged.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) =
  FunDef entry attrs name rettype params
    <$> localScope (scopeOfFParams params) (transformBody body)

-- Sets the chunk size to one.
unstreamLambda :: Attrs -> [SubExp] -> Lambda SOACS -> ExtractM (Lambda SOACS)
-- | Turn the lambda of a chunked fold into a lambda that processes a
-- single element at a time.
--
-- The parameters of @lam@ are partitioned (via
-- 'partitionChunkedFoldParameters') into:
--
-- * a chunk-size parameter;
-- * one accumulator parameter per neutral element in @nes@;
-- * the chunked input parameters.
--
-- The new lambda takes one fresh parameter per chunked input, typed as a
-- row of that input. Its body first binds:
--
-- * the chunk size to @1@;
-- * each accumulator to its neutral element;
-- * each chunk parameter to a singleton array holding the element.
--
-- The original body follows. The first @length nes@ results are returned
-- as they are. Every remaining result is indexed at position 0 to
-- extract its single element.
--
-- The new lambda is simplified. It is then sequentialised with
-- 'FOT.transformLambda' when @attrs@ contains @sequential_inner@.
unstreamLambda :: Attrs -> [SubExp] -> Lambda SOACS -> ExtractM (Lambda SOACS)
unstreamLambda attrs nes lam = do
  let num_accs = length nes
      (chunk_param, acc_params, slice_params) =
        partitionChunkedFoldParameters num_accs (lambdaParams lam)

  -- One fresh parameter per chunked input, typed as a single row of it.
  elem_params <-
    mapM (\(Param _ v t) -> newParam (baseString v) (rowType t)) slice_params

  new_body <- runBodyBuilder . localScope (scopeOfLParams elem_params) $ do
    -- Every chunk now contains exactly one element.
    letBindNames [paramName chunk_param] . BasicOp . SubExp $
      intConst Int64 1

    -- The accumulators start out as the neutral elements.
    mapM_
      (\(acc, ne) -> letBindNames [paramName acc] . BasicOp $ SubExp ne)
      (zip acc_params nes)

    -- Each chunk parameter becomes a singleton array of the new element.
    mapM_
      ( \(slice, x) ->
          letBindNames [paramName slice] . BasicOp $
            ArrayLit [Var (paramName x)] (paramType x)
      )
      (zip slice_params elem_params)

    (acc_res, chunk_res) <- splitAt num_accs <$> bodyBind (lambdaBody lam)

    -- The non-accumulator results are singleton arrays; extract their
    -- only element, keeping each result's certificates.
    elem_res <- forM chunk_res $ \(SubExpRes cs se) -> do
      arr <- letExp "map_res" . BasicOp $ SubExp se
      arr_t <- lookupType arr
      certifying cs . letSubExp "chunk" . BasicOp . Index arr $
        fullSlice arr_t [DimFix (intConst Int64 0)]

    pure . mkBody mempty $ acc_res <> subExpsRes elem_res

  let (red_ts, map_ts) = splitAt num_accs (lambdaReturnType lam)
      map_lam =
        Lambda
          { lambdaParams = elem_params,
            lambdaBody = new_body,
            lambdaReturnType = red_ts ++ map rowType map_ts
          }

  soacs_scope <- castScope <$> askScope
  simplified_lam <- runReaderT (SOACS.simplifyLambda map_lam) soacs_scope

  if "sequential_inner" `inAttrs` attrs
    then FOT.transformLambda simplified_lam
    else pure simplified_lam

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).

-- | Whether 'renameIfNeeded' should give a value fresh names.
--
-- In 'transformSOAC', the sequential version of an operation is built
-- with 'DoNotRename'. The parallel version, built from the same lambda,
-- uses 'DoRename', presumably so the two versions do not share bound
-- names.
data NeedsRename
  = DoRename
  | DoNotRename

-- | Give the value fresh names with 'renameSomething' for 'DoRename';
-- return it unchanged for 'DoNotRename'.
renameIfNeeded :: Rename a => NeedsRename -> a -> ExtractM a
renameIfNeeded needs_rename x =
  case needs_rename of
    DoRename -> renameSomething x
    DoNotRename -> pure x

-- | Build a 'SegMap' that applies @map_lam@ to the elements of @arrs@.
--
-- The segment space is freshly created with 'mkSegSpace' from the width
-- @w@. The lambda body is converted with the supplied body transformer,
-- and the result is renamed according to the 'NeedsRename' flag.
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap needs_rename body_transform w map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody body_transform gtid map_lam arrs
  let ret_ts = lambdaReturnType map_lam
      segmap = SegMap () space ret_ts kbody
  renameIfNeeded needs_rename segmap

-- | Build a 'SegRed' for a map-reduce (redomap) over @arrs@.
--
-- Each 'Reduce' is converted with 'reduceToSegBinOp'. This yields one
-- group of statements per reduction, which are returned alongside the
-- operator. The caller in 'transformSOAC' places them before the
-- operator's binding. The body is converted with the supplied body
-- transformer, and the operator is renamed according to the
-- 'NeedsRename' flag.
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap needs_rename body_transform w reds map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody body_transform gtid map_lam arrs
  converted <- traverse reduceToSegBinOp reds
  let (aux_stms, seg_binops) = unzip converted
  segred <-
    renameIfNeeded needs_rename $
      SegRed () space seg_binops (lambdaReturnType map_lam) kbody
  pure (aux_stms, segred)

-- | Build a 'SegHist' for a histogram computation over @arrs@.
--
-- Each histogram operator is converted with 'histToSegBinOp'. This
-- yields one group of statements per operator, which are returned
-- alongside the resulting 'SegOp'. The body is converted with the
-- supplied body transformer, and the operator is renamed according to
-- the 'NeedsRename' flag.
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist needs_rename body_transform w hists map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody body_transform gtid map_lam arrs
  (aux_stms, hist_ops) <- unzip <$> traverse histToSegBinOp hists
  let seghist = SegHist () space hist_ops (lambdaReturnType map_lam) kbody
  seghist' <- renameIfNeeded needs_rename seghist
  pure (aux_stms, seghist')

-- | Build a 'SegRed' for a parallel stream.
--
-- The single reduction is formed from @comm@, @red_lam@ and @red_nes@
-- and converted with 'reduceToSegBinOp'. Its statements are returned
-- alongside the operator. The map lambda's body is converted with the
-- supplied body transformer, and the operator is renamed according to
-- the 'NeedsRename' flag.
transformParStream ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Commutativity ->
  Lambda SOACS ->
  [SubExp] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Stms MC, SegOp () MC)
transformParStream needs_rename body_transform w comm red_lam red_nes map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody body_transform gtid map_lam arrs
  (red_stms, red) <- reduceToSegBinOp (Reduce comm red_lam red_nes)
  let segred = SegRed () space [red] (lambdaReturnType map_lam) kbody
  (red_stms,) <$> renameIfNeeded needs_rename segred

transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC Pat Type
pat Attrs
_ (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
  | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form = do
      SegOp () MC
seq_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w Lambda SOACS
lam [VName]
arrs
      if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam
        then do
          SegOp () MC
par_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w Lambda SOACS
lam [VName]
arrs
          Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
        else Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form = do
      ([Stms MC]
seq_reds_stms, SegOp () MC
seq_op) <-
        NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
      if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam
        then do
          ([Stms MC]
par_reds_stms, SegOp () MC
par_op) <-
            NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
          Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
            [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat ([Stms MC]
seq_reds_stms [Stms MC] -> [Stms MC] -> [Stms MC]
forall a. Semigroup a => a -> a -> a
<> [Stms MC]
par_reds_stms)
              Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
        else
          Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
            [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
seq_reds_stms
              Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = do
      (VName
gtid, SegSpace
space) <- SubExp -> ExtractM (VName, SegSpace)
forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w
      KernelBody MC
kbody <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (KernelBody MC)
mapLambdaToKernelBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
map_lam [VName]
arrs
      ([Stms MC]
scans_stms, [SegBinOp MC]
scans') <- [(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC]))
-> ExtractM [(Stms MC, SegBinOp MC)]
-> ExtractM ([Stms MC], [SegBinOp MC])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Scan SOACS -> ExtractM (Stms MC, SegBinOp MC))
-> [Scan SOACS] -> ExtractM [(Stms MC, SegBinOp MC)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp [Scan SOACS]
scans
      Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
        [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
scans_stms
          Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm
            ( Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$
                Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$
                  Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing (SegOp () MC -> MCOp MC (SOAC MC))
-> SegOp () MC -> MCOp MC (SOAC MC)
forall a b. (a -> b) -> a -> b
$
                    ()
-> SegSpace
-> [SegBinOp MC]
-> [Type]
-> KernelBody MC
-> SegOp () MC
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegScan () SegSpace
space [SegBinOp MC]
scans' (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
map_lam) KernelBody MC
kbody
            )
  | Bool
otherwise = do
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      Scope SOACS
scope <- Scope MC -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope (Scope MC -> Scope SOACS)
-> ExtractM (Scope MC) -> ExtractM (Scope SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ExtractM (Scope MC)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
      Stms SOACS -> ExtractM (Stms MC)
transformStms (Stms SOACS -> ExtractM (Stms MC))
-> ExtractM (Stms SOACS) -> ExtractM (Stms MC)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS ExtractM () -> Scope SOACS -> ExtractM (Stms SOACS)
forall (m :: * -> *) rep.
MonadFreshNames m =>
BuilderT rep m () -> Scope rep -> m (Stms rep)
runBuilderT_ (Pat (LetDec (Rep (BuilderT SOACS ExtractM)))
-> SubExp
-> ScremaForm (Rep (BuilderT SOACS ExtractM))
-> [VName]
-> BuilderT SOACS ExtractM ()
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat Type
Pat (LetDec (Rep (BuilderT SOACS ExtractM)))
pat SubExp
w ScremaForm (Rep (BuilderT SOACS ExtractM))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
transformSOAC Pat Type
pat Attrs
_ (Scatter SubExp
w [VName]
ivs Lambda SOACS
lam [(Shape, Int, VName)]
dests) = do
  (VName
gtid, SegSpace
space) <- SubExp -> ExtractM (VName, SegSpace)
forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w

  Body () Stms MC
kstms Result
res <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (Body MC)
mapLambdaToBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
lam [VName]
ivs

  let rets :: [Type]
rets = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
takeLast ([(Shape, Int, VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, Int, VName)]
dests) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam
      kres :: [KernelResult]
kres = do
        (Shape
a_w, VName
a, [(Result, SubExpRes)]
is_vs) <- [(Shape, Int, VName)]
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests Result
res
        let cs :: Certs
cs =
              ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((SubExpRes -> Certs) -> Result -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts (Result -> Certs)
-> ((Result, SubExpRes) -> Result) -> (Result, SubExpRes) -> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> Result
forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
                Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts (SubExpRes -> Certs)
-> ((Result, SubExpRes) -> SubExpRes)
-> (Result, SubExpRes)
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
            is_vs' :: [(Slice SubExp, SubExp)]
is_vs' = [([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> DimIndex SubExp) -> Result -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (SubExpRes -> SubExp) -> SubExpRes -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
is, SubExpRes -> SubExp
resSubExp SubExpRes
v) | (Result
is, SubExpRes
v) <- [(Result, SubExpRes)]
is_vs]
        KernelResult -> [KernelResult]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ Certs -> Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Certs
cs Shape
a_w VName
a [(Slice SubExp, SubExp)]
is_vs'
      kbody :: KernelBody MC
kbody = BodyDec MC -> Stms MC -> [KernelResult] -> KernelBody MC
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms MC
kstms [KernelResult]
kres
  Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
    Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Stm MC -> Stms MC) -> Stm MC -> Stms MC
forall a b. (a -> b) -> a -> b
$
      Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$
        Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$
          Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing (SegOp () MC -> MCOp MC (SOAC MC))
-> SegOp () MC -> MCOp MC (SOAC MC)
forall a b. (a -> b) -> a -> b
$
            () -> SegSpace -> [Type] -> KernelBody MC -> SegOp () MC
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap () SegSpace
space [Type]
rets KernelBody MC
kbody
-- Histograms: a sequentialised version of the operation is always
-- produced.  When the map lambda contains nested parallelism, a second
-- version is also produced from 'transformBody', and the resulting
-- 'ParOp' carries both (parallel first, sequential second); otherwise
-- the 'ParOp' carries only the sequential one.
transformSOAC pat _ (Hist w arrs hists map_lam) = do
  (seq_hist_stms, seq_op) <-
    transformHist DoNotRename sequentialiseBody w hists map_lam arrs
  -- The second traversal passes 'DoRename' (the first 'DoNotRename'),
  -- presumably so the two variants do not share binding names.
  -- 'fmap Just' wraps the second component of the returned pair.
  (par_hist_stms, par_op) <-
    if lambdaContainsParallelism map_lam
      then
        fmap Just
          <$> transformHist DoRename transformBody w hists map_lam arrs
      else pure ([], Nothing)
  pure $
    mconcat (seq_hist_stms <> par_hist_stms)
      <> oneStm (Let pat (defAux ()) $ Op $ ParOp par_op seq_op)
-- Parallel streams with a non-empty reduction: the stream's fold
-- lambda is turned into a map lambda via 'unstreamLambda', and the
-- stream is then handled like a reduction.  A sequentialised version is
-- always produced; a second one from 'transformBody' is produced only
-- when the map lambda contains nested parallelism.  Streams with no
-- reduction neutral elements fail the guard and fall through to the
-- next equation.
transformSOAC pat attrs (Stream w arrs (Parallel _ comm red_lam) red_nes fold_lam)
  | not $ null red_nes = do
      map_lam <- unstreamLambda attrs red_nes fold_lam
      (seq_red_stms, seq_op) <-
        transformParStream
          DoNotRename
          sequentialiseBody
          w
          comm
          red_lam
          red_nes
          map_lam
          arrs
      -- Only build the second variant when it can exploit nested
      -- parallelism; 'fmap Just' wraps the second component of the pair.
      (par_red_stms, par_op) <-
        if lambdaContainsParallelism map_lam
          then
            fmap Just
              <$> transformParStream
                DoRename
                transformBody
                w
                comm
                red_lam
                red_nes
                map_lam
                arrs
          else pure (mempty, Nothing)
      pure $
        seq_red_stms
          <> par_red_stms
          <> oneStm (Let pat (defAux ()) $ Op $ ParOp par_op seq_op)
-- Any remaining stream: eliminate it by sequentialising the whole
-- array in a single chunk, then transform the resulting statements.
transformSOAC pat _ (Stream w arrs _ nes lam) = do
  -- Just remove the stream and transform the resulting stms.
  -- The builder runs in the SOACS representation, so the current MC
  -- scope is cast to a SOACS scope first.
  soacs_scope <- castScope <$> askScope
  stream_stms <-
    runBuilderT_ (sequentialStreamWholeArray pat w nes lam arrs) soacs_scope
  transformStms stream_stms

-- | Transform the top-level constants and then every function
-- definition of a program.  The function definitions are transformed in
-- the scope of the already-transformed constants.  The pass's name
-- source is threaded through the underlying 'State' of 'ExtractM'.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg (Prog consts funs) =
  modifyNameSource $ runState (runReaderT m mempty)
  where
    -- Unwrap the 'ExtractM' newtype to reach the reader/state stack;
    -- the reader starts from an empty scope ('mempty').
    ExtractM m = do
      consts' <- transformStms consts
      funs' <- inScopeOf consts' $ mapM transformFunDef funs
      pure $ Prog consts' funs'

-- | Transform a program using SOACs to a program in the 'MC'
-- representation, using some amount of flattening.
extractMulticore :: Pass SOACS MC
extractMulticore =
  Pass
    { passName = "extract multicore parallelism",
      passDescription = "Extract multicore parallelism",
      passFunction = transformProg
    }