{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bitraversable
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.MC
import qualified Futhark.IR.MC as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pattern,
    Stm,
  )
import qualified Futhark.IR.SOACS as SOACS
import qualified Futhark.IR.SOACS.Simplify as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToGPU (injectSOACS)
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util (takeLast)
import Futhark.Util.Log

-- | The monad in which extraction runs: a reader over the current 'MC'
-- type scope, layered over a state carrying the fresh-name source.
-- Every instance below is obtained by newtype deriving from that
-- underlying transformer stack.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope MC,
      LocalScope MC,
      MonadFreshNames
    )

-- XXX: throwing away the log here...
-- Logging in 'ExtractM' is a no-op: every message handed to 'addLog'
-- is discarded.
instance MonadLogger ExtractM where
  addLog _ = pure ()

-- | @indexArray i p arr@ builds a statement that binds the lambda
-- parameter @p@ to the slice of @arr@ at index @i@ of its outermost
-- dimension (all remaining dimensions are taken whole).  For a
-- parameter of accumulator type ('Acc') no indexing is performed:
-- @p@ is bound directly to @arr@.
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param p t) arr =
  Let (Pattern [] [PatElem p t]) (defAux ()) . BasicOp $
    case t of
      Acc {} -> SubExp $ Var arr
      _ -> Index arr $ DimFix (Var i) : map sliceDim (arrayDims t)

-- | Turn the lambda of a @map@ into the body executed for iteration
-- @i@.  Each lambda parameter is bound to the @i@-th element of the
-- corresponding array in @arrs@ (see 'indexArray').  The lambda body
-- is then converted with @onBody@ while those bindings are in scope,
-- and the resulting statements are placed after the indexing
-- statements.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs = do
  let indexings = zipWith (indexArray i) (lambdaParams lam) arrs
  Body () stms res <- inScopeOf indexings $ onBody $ lambdaBody lam
  pure $ Body () (stmsFromList indexings <> stms) res

-- | Like 'mapLambdaToBody', but produces a 'KernelBody'.  Every result
-- of the body is returned as @'Returns' 'ResultMaySimplify'@.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs = do
  Body () stms res <- mapLambdaToBody onBody i lam arrs
  pure $ KernelBody () stms $ map (Returns ResultMaySimplify) res

-- | Convert a SOACS 'Reduce' into a 'SegBinOp', keeping its
-- commutativity.  'determineReduceOp' runs in a binder.  The
-- statements it emits are returned alongside the operator and must be
-- placed before any use of the neutral elements.  The rewritten lambda
-- it yields is then transformed to 'MC'.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((lam', nes', shape), stms) <- runBinder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  pure (stms, SegBinOp comm lam'' nes' shape)

-- | Convert a SOACS 'Scan' into a 'SegBinOp'.  A 'Scan' carries no
-- commutativity information, so the operator is always marked
-- 'Noncommutative'.  As in 'reduceToSegBinOp', the statements emitted
-- by 'determineReduceOp' are returned alongside the operator.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((lam', nes', shape), stms) <- runBinder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  pure (stms, SegBinOp Noncommutative lam'' nes' shape)

-- | Convert a SOACS histogram operator into an 'MC.HistOp'.  The bin
-- count, @rf@ and destination arrays are carried over unchanged.  The
-- operator and neutral elements are passed through
-- 'determineReduceOp', whose emitted statements are returned alongside
-- the result.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((op', nes', shape), stms) <- runBinder $ determineReduceOp op nes
  op'' <- transformLambda op'
  pure (stms, MC.HistOp num_bins rf dests nes' shape op'')

-- | Create a one-dimensional 'SegSpace' of size @w@ with fresh names
-- for the flat thread index (@flat_tid@) and the per-dimension index
-- (@gtid@).  Returns the @gtid@ name together with the space.
mkSegSpace :: MonadFreshNames m => SubExp -> m (VName, SegSpace)
mkSegSpace w = do
  flat <- newVName "flat_tid"
  gtid <- newVName "gtid"
  let space = SegSpace flat [(gtid, w)]
  pure (gtid, space)

-- | Re-tag a loop form from 'SOACS' to 'MC'.  The contents are
-- unchanged; only the representation index of the type changes.
transformLoopForm :: LoopForm SOACS -> LoopForm MC
transformLoopForm (WhileLoop cond) = WhileLoop cond
transformLoopForm (ForLoop i it bound params) = ForLoop i it bound params

-- | Translate a single SOACS statement into zero or more MC statements.
--
-- * Basic operations and function applications are re-tagged as-is.
-- * Loops and branches have their bodies translated recursively; loop
--   bodies see the loop's merge parameters and loop-form bindings in
--   scope.
-- * 'WithAcc' has the operator lambdas of its inputs and its main
--   lambda translated.
-- * SOACs are delegated to 'transformSOAC' (defined elsewhere in this
--   module).  The original statement's certificates are attached to
--   every statement it produces.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
transformStm (Let pat aux (BasicOp op)) =
  pure $ oneStm $ Let pat aux $ BasicOp op
transformStm (Let pat aux (Apply f args ret info)) =
  pure $ oneStm $ Let pat aux $ Apply f args ret info
transformStm (Let pat aux (DoLoop ctx val form body)) = do
  let form' = transformLoopForm form
  body' <-
    localScope
      ( scopeOfFParams (map fst ctx)
          <> scopeOfFParams (map fst val)
          <> scopeOf form'
      )
      $ transformBody body
  pure $ oneStm $ Let pat aux $ DoLoop ctx val form' body'
transformStm (Let pat aux (If cond tbranch fbranch ret)) =
  oneStm . Let pat aux
    <$> (If cond <$> transformBody tbranch <*> transformBody fbranch <*> pure ret)
transformStm (Let pat aux (WithAcc inputs lam)) =
  oneStm . Let pat aux
    <$> (WithAcc <$> mapM transformInput inputs <*> transformLambda lam)
  where
    -- Only the optional operator lambda of each input needs translating;
    -- the shape, arrays and neutral elements are kept as they are.
    transformInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (bitraverse transformLambda pure) op
transformStm (Let pat aux (Op op)) =
  fmap (certify (stmAuxCerts aux)) <$> transformSOAC pat (stmAuxAttrs aux) op

-- | Translate a lambda from 'SOACS' to 'MC'.  The body is translated
-- with the lambda's parameters in scope; the parameters and return
-- types are carried over unchanged.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret

-- | Translate a sequence of statements in order.  The statements
-- produced for each input statement are put in scope before the
-- remainder of the sequence is translated, so later statements can
-- look up the types of earlier bindings.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms stms =
  case stmsHead stms of
    Nothing -> pure mempty
    Just (stm, stms') -> do
      stm_stms <- transformStm stm
      inScopeOf stm_stms $ (stm_stms <>) <$> transformStms stms'

-- | Translate a body by translating its statements.  The result
-- subexpressions are kept unchanged.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Convert a body to 'MC' without extracting any parallelism.  The
-- body is rephrased with 'injectSOACS', which wraps every SOAC in an
-- MC 'OtherOp'.  No fresh names or scope are needed, so the rephrasing
-- runs purely in 'Identity'.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody = pure . runIdentity . rephraseBody toMC
  where
    toMC = injectSOACS OtherOp

-- | Translate a function definition.  The body is translated with the
-- function's parameters in scope.  Entry-point information,
-- attributes, name, return types and parameters are carried over
-- unchanged.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) = do
  body' <- localScope (scopeOfFParams params) $ transformBody body
  pure $ FunDef entry attrs name rettype params body'

-- | Sets the chunk size to one: turn a chunked stream lambda into a
-- per-element lambda.
--
-- The parameters of @lam@ are split with 'partitionChunkedFoldParameters'
-- (using @length nes@) into a chunk-size parameter, accumulator
-- parameters, and slice parameters.  The new lambda takes one fresh
-- parameter per slice parameter, typed as the row type of that slice.
-- Inside its body:
--
-- * the chunk size is bound to the constant @1@;
-- * each accumulator parameter is bound to its neutral element from @nes@;
-- * each slice parameter is bound to a one-element array literal holding
--   the corresponding new parameter.
--
-- The first @length nes@ results of the original body are returned
-- unchanged.  The remaining results are indexed at position 0, and their
-- return types are replaced by their row types.  The lambda is then
-- simplified.  If the @sequential_inner@ attribute is set, it is also
-- passed through 'FOT.transformLambda'.
unstreamLambda :: Attrs -> [SubExp] -> Lambda SOACS -> ExtractM (Lambda SOACS)
unstreamLambda attrs nes lam = do
  let num_accs = length nes
      (chunk_param, acc_params, slice_params) =
        partitionChunkedFoldParameters num_accs (lambdaParams lam)

  -- One fresh parameter per slice, holding a single row of that slice.
  inp_params <-
    mapM (\(Param p t) -> newParam (baseString p) (rowType t)) slice_params

  body <- runBodyBinder . localScope (scopeOfLParams inp_params) $ do
    -- The chunk size is fixed to one.
    letBindNames [paramName chunk_param] . BasicOp . SubExp $ intConst Int64 1

    -- Accumulators start out as their neutral elements.
    forM_ (zip acc_params nes) $ \(acc, ne) ->
      letBindNames [paramName acc] . BasicOp $ SubExp ne

    -- Each slice becomes a singleton array wrapping the new parameter.
    forM_ (zip slice_params inp_params) $ \(slice, inp) ->
      letBindNames [paramName slice] . BasicOp $
        ArrayLit [Var (paramName inp)] (paramType inp)

    (red_res, map_res) <- splitAt num_accs <$> bodyBind (lambdaBody lam)

    -- Mapped results are taken at index 0 to get back a single element.
    map_res' <- forM map_res $ \res_se -> do
      arr <- letExp "map_res" . BasicOp $ SubExp res_se
      arr_t <- lookupType arr
      letSubExp "chunk" . BasicOp . Index arr $
        fullSlice arr_t [DimFix (intConst Int64 0)]

    pure . resultBody $ red_res <> map_res'

  let (red_ts, map_ts) = splitAt num_accs (lambdaReturnType lam)
      map_lam =
        Lambda
          { lambdaParams = inp_params,
            lambdaBody = body,
            lambdaReturnType = red_ts ++ map rowType map_ts
          }

  soacs_scope <- castScope <$> askScope
  simplified <- runReaderT (SOACS.simplifyLambda map_lam) soacs_scope

  let sequential = "sequential_inner" `inAttrs` attrs
  if sequential
    then FOT.transformLambda simplified
    else pure simplified

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).

-- | Whether a freshly built value must have its names freshened before use.
data NeedsRename = DoRename | DoNotRename

-- | Apply 'renameSomething' when given 'DoRename'; otherwise return the
-- argument unchanged.
renameIfNeeded :: Rename a => NeedsRename -> a -> ExtractM a
renameIfNeeded needs x =
  case needs of
    DoRename -> renameSomething x
    DoNotRename -> pure x

-- | Build a 'SegMap' over a segment space of size @w@.
--
-- The kernel body comes from @map_lam@ applied to @arrs@, with @onBody@
-- translating the lambda body.  The resulting 'SegOp' is renamed
-- according to the 'NeedsRename' argument.
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap rename onBody w map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  let segmap = SegMap () space (lambdaReturnType map_lam) kbody
  renameIfNeeded rename segmap

-- | Build a 'SegRed' for a redomap.
--
-- The map lambda becomes the kernel body.  Every 'Reduce' is converted
-- with 'reduceToSegBinOp', and the statements each conversion yields are
-- returned (one 'Stms' per reduction) next to the possibly renamed
-- 'SegOp'.
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap rename onBody w reds map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (pre_stms, seg_reds) <- unzip <$> traverse reduceToSegBinOp reds
  let segred = SegRed () space seg_reds (lambdaReturnType map_lam) kbody
  (pre_stms,) <$> renameIfNeeded rename segred

-- | Build a 'SegHist' for a histogram SOAC.
--
-- The map lambda becomes the kernel body.  Each histogram operator is
-- converted with 'histToSegBinOp', and the statements each conversion
-- yields are returned (one 'Stms' per operator) next to the possibly
-- renamed 'SegOp'.
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist rename onBody w hists map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (pre_stms, seg_hists) <- unzip <$> traverse histToSegBinOp hists
  let seghist = SegHist () space seg_hists (lambdaReturnType map_lam) kbody
  (pre_stms,) <$> renameIfNeeded rename seghist

-- | Build a 'SegRed' with exactly one reduction operator,
-- @Reduce comm red_lam red_nes@, over the kernel body obtained from
-- @map_lam@ and @arrs@.
--
-- Returns the statements produced by 'reduceToSegBinOp' together with
-- the possibly renamed 'SegOp'.
transformParStream ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Commutativity ->
  Lambda SOACS ->
  [SubExp] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Stms MC, SegOp () MC)
transformParStream rename onBody w comm red_lam red_nes map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (red_stms, seg_red) <- reduceToSegBinOp (Reduce comm red_lam red_nes)
  let segred = SegRed () space [seg_red] (lambdaReturnType map_lam) kbody
  (red_stms,) <$> renameIfNeeded rename segred

transformSOAC :: Pattern SOACS -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC :: Pattern SOACS -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC Pattern SOACS
pat Attrs
_ (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
  | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form = do
    SegOp () MC
seq_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w Lambda SOACS
lam [VName]
arrs
    if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam
      then do
        SegOp () MC
par_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w Lambda SOACS
lam [VName]
arrs
        Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> ExpT rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
      else Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> ExpT rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form = do
    ([Stms MC]
seq_reds_stms, SegOp () MC
seq_op) <-
      NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
    if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam
      then do
        ([Stms MC]
par_reds_stms, SegOp () MC
par_op) <-
          NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
        Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
          [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat ([Stms MC]
seq_reds_stms [Stms MC] -> [Stms MC] -> [Stms MC]
forall a. Semigroup a => a -> a -> a
<> [Stms MC]
par_reds_stms)
            Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> ExpT rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
      else
        Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
          [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
seq_reds_stms
            Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> ExpT rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = do
    (VName
gtid, SegSpace
space) <- SubExp -> ExtractM (VName, SegSpace)
forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w
    KernelBody MC
kbody <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (KernelBody MC)
mapLambdaToKernelBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
map_lam [VName]
arrs
    ([Stms MC]
scans_stms, [SegBinOp MC]
scans') <- [(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC]))
-> ExtractM [(Stms MC, SegBinOp MC)]
-> ExtractM ([Stms MC], [SegBinOp MC])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Scan SOACS -> ExtractM (Stms MC, SegBinOp MC))
-> [Scan SOACS] -> ExtractM [(Stms MC, SegBinOp MC)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp [Scan SOACS]
scans
    Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
      [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
scans_stms
        Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm
          ( Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$
              Op MC -> Exp MC
forall rep. Op rep -> ExpT rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$
                Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing (SegOp () MC -> MCOp MC (SOAC MC))
-> SegOp () MC -> MCOp MC (SOAC MC)
forall a b. (a -> b) -> a -> b
$
                  ()
-> SegSpace
-> [SegBinOp MC]
-> [TypeBase Shape NoUniqueness]
-> KernelBody MC
-> SegOp () MC
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegScan () SegSpace
space [SegBinOp MC]
scans' (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. LambdaT rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
map_lam) KernelBody MC
kbody
          )
  | Bool
otherwise = do
    -- This screma is too complicated for us to immediately do
    -- anything, so split it up and try again.
    Scope SOACS
scope <- Scope MC -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope (Scope MC -> Scope SOACS)
-> ExtractM (Scope MC) -> ExtractM (Scope SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ExtractM (Scope MC)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
    Stms SOACS -> ExtractM (Stms MC)
transformStms (Stms SOACS -> ExtractM (Stms MC))
-> ExtractM (Stms SOACS) -> ExtractM (Stms MC)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT SOACS ExtractM () -> Scope SOACS -> ExtractM (Stms SOACS)
forall (m :: * -> *) rep.
MonadFreshNames m =>
BinderT rep m () -> Scope rep -> m (Stms rep)
runBinderT_ (Pattern (Rep (BinderT SOACS ExtractM))
-> SubExp
-> ScremaForm (Rep (BinderT SOACS ExtractM))
-> [VName]
-> BinderT SOACS ExtractM ()
forall (m :: * -> *).
(MonadBinder m, Op (Rep m) ~ SOAC (Rep m), Bindable (Rep m)) =>
Pattern (Rep m) -> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pattern (Rep (BinderT SOACS ExtractM))
Pattern SOACS
pat SubExp
w ScremaForm (Rep (BinderT SOACS ExtractM))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
-- Scatter: becomes a single 'SegMap' over the @w@ iterations, with no
-- sequential alternative ('ParOp' 'Nothing').
--
-- The lambda is turned into a body (via 'mapLambdaToBody' and
-- 'transformBody') indexed by the segment's thread id.  Its results are
-- grouped per destination by 'groupScatterResults'; each group becomes one
-- 'WriteReturns' kernel result whose indices are all 'DimFix' slices.
transformSOAC pat _ (Scatter w lam ivs dests) = do
  (gtid, space) <- mkSegSpace w

  Body () kstms res <- mapLambdaToBody transformBody gtid lam ivs

  let -- The SegMap's declared result types are the last (length dests)
      -- return types of the scatter lambda.
      rets = takeLast (length dests) $ lambdaReturnType lam
      -- One WriteReturns per destination array.
      kres =
        [ WriteReturns a_w a [(map DimFix is, v) | (is, v) <- is_vs]
          | (a_w, a, is_vs) <- groupScatterResults dests res
        ]
      kbody = KernelBody () kstms kres
  return . oneStm . Let pat (defAux ()) . Op . ParOp Nothing $
    SegMap () space rets kbody
-- Hist: a sequentialised version of the operation is always produced.
--
-- If the map lambda contains nested parallelism, a second, parallel version
-- is also produced.  It uses 'transformBody' and 'DoRename', so it does not
-- share names with the sequential copy.  Both versions are offered through
-- one 'ParOp'.
--
-- The statements returned by 'transformHist' are emitted before the
-- operation: sequential ones first, then parallel ones.
transformSOAC pat _ (Hist w hists map_lam arrs) = do
  (seq_hist_stms, seq_op) <-
    transformHist DoNotRename sequentialiseBody w hists map_lam arrs

  -- The sequential version is generated first, preserving name-source order.
  (par_hist_stms, par_op) <-
    if lambdaContainsParallelism map_lam
      then do
        (stms, op) <- transformHist DoRename transformBody w hists map_lam arrs
        return (stms, Just op)
      else return ([], Nothing)

  return $
    mconcat (seq_hist_stms <> par_hist_stms)
      <> oneStm (Let pat (defAux ()) $ Op $ ParOp par_op seq_op)
-- Parallel stream with a non-empty list of reduction neutral elements.
--
-- The fold lambda is first converted with 'unstreamLambda'.  Then, as for
-- 'Hist', a sequentialised version is always produced.  A parallel version
-- (renamed via 'DoRename') is added only when the resulting lambda contains
-- nested parallelism.
--
-- Streams with no neutral elements fail the guard and fall through to the
-- next equation.
transformSOAC pat attrs (Stream w arrs (Parallel _ comm red_lam) red_nes fold_lam)
  | not $ null red_nes = do
    map_lam <- unstreamLambda attrs red_nes fold_lam

    (seq_red_stms, seq_op) <-
      transformParStream
        DoNotRename
        sequentialiseBody
        w
        comm
        red_lam
        red_nes
        map_lam
        arrs

    -- The sequential version is generated first, preserving name-source order.
    (par_red_stms, par_op) <-
      if lambdaContainsParallelism map_lam
        then do
          (stms, op) <-
            transformParStream DoRename transformBody w comm red_lam red_nes map_lam arrs
          return (stms, Just op)
        else return (mempty, Nothing)

    return $
      seq_red_stms <> par_red_stms
        <> oneStm (Let pat (defAux ()) $ Op $ ParOp par_op seq_op)
-- Any other stream: the stream is removed by sequentialising it over the
-- whole array ('sequentialStreamWholeArray').  This runs in a SOACS binder
-- seeded with the current scope, cast from MC.  The resulting SOACS
-- statements are then transformed as usual.
transformSOAC pat _ (Stream w arrs _ nes lam) = do
  -- Just remove the stream and transform the resulting stms.
  soacs_scope <- castScope <$> askScope
  stream_stms <-
    flip runBinderT_ soacs_scope $
      sequentialStreamWholeArray pat w nes lam arrs
  transformStms stream_stms

-- | Transform a whole SOACS program into MC.
--
-- The top-level constants are transformed first.  Every function definition
-- is then transformed in the scope of the transformed constants.
--
-- The 'ExtractM' computation starts from an empty scope ('mempty').  It
-- threads the pass's name source through 'modifyNameSource'.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg (Prog consts funs) =
  modifyNameSource $ runState (runReaderT m mempty)
  where
    ExtractM m = do
      consts' <- transformStms consts
      funs' <- inScopeOf consts' $ mapM transformFunDef funs
      return $ Prog consts' funs'

-- | The pass that turns a SOACS program into an MC program by extracting
-- multicore parallelism.  All the work is done by 'transformProg'.
extractMulticore :: Pass SOACS MC
extractMulticore =
  Pass
    { passName = "extract multicore parallelism",
      passDescription = "Extract multicore parallelism",
      passFunction = transformProg
    }