{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Sequentialise any remaining SOACs.  It is very important that
-- this is run *after* any access-pattern-related optimisation,
-- because this pass will destroy information.
--
-- This pass conceptually contains three subpasses:
--
-- 1. Sequentialise 'Stream' operations, leaving other SOACs intact.
--
-- 2. Apply whole-program simplification.
--
-- 3. Sequentialise remaining SOACs.
--
-- This is because sequentialisation of streams creates many SOACs
-- operating on single-element arrays, which can be efficiently
-- simplified away, but only *before* they are turned into loops.  In
-- principle this pass could be split into multiple passes, but for now
-- it is kept together.
module Futhark.Optimise.Unstream (unstreamGPU, unstreamMC) where

import Control.Monad.Reader
import Control.Monad.State
import Futhark.IR.GPU
import qualified Futhark.IR.GPU as GPU
import Futhark.IR.GPU.Simplify (simplifyGPU)
import Futhark.IR.MC
import qualified Futhark.IR.MC as MC
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | Unstreaming for the GPU representation.  SOAC operations
-- ('GPU.OtherOp') and the bodies of 'SegOp's are handled by
-- 'onHostOp', and 'simplifyGPU' is the simplification run between the
-- two sequentialisation sweeps.
unstreamGPU :: Pass GPU GPU
unstreamGPU =
  unstream onHostOp simplifyGPU

-- | Unstreaming for the multicore representation.  'ParOp's and SOAC
-- operations ('MC.OtherOp') are handled by 'onMCOp', and
-- 'MC.simplifyProg' is the simplification run between the two
-- sequentialisation sweeps.
unstreamMC :: Pass MC MC
unstreamMC =
  unstream onMCOp MC.simplifyProg

-- | Which sequentialisation sweep is currently running (see
-- 'sequentialise').
data Stage
  = -- | First sweep: only 'Stream' SOACs are turned into loops.
    SeqStreams
  | -- | Final sweep: every remaining SOAC is turned into loops.
    SeqAll

-- | Build the unstreaming pass from a representation-specific handler
-- for 'Op's and a whole-program simplifier.
--
-- The resulting pass performs three steps in sequence:
--
--   1. a sweep over every function that sequentialises only 'Stream's
--      ('SeqStreams');
--
--   2. the supplied simplifier;
--
--   3. a sweep that sequentialises all remaining SOACs ('SeqAll').
unstream ::
  ASTRep rep =>
  (Stage -> OnOp rep) ->
  (Prog rep -> PassM (Prog rep)) ->
  Pass rep rep
unstream onOp simplify =
  Pass "unstream" "sequentialise remaining SOACs" pipeline
  where
    pipeline = sweep SeqStreams >=> simplify >=> sweep SeqAll

    sweep stage = intraproceduralTransformation (optimise stage)

    -- Rewrite the statements of one function in the given scope,
    -- borrowing the pass's name source for the duration of the rewrite.
    optimise stage scope stms =
      let action = optimiseStms (onOp stage) stms
       in modifyNameSource $ runState (runReaderT action scope)

-- | The rewriting monad: a read-only typing scope layered over a state
-- that carries the name source used when fresh names are needed.
type UnstreamM rep =
  ReaderT (Scope rep) (State VNameSource)

-- | A handler for the representation-specific 'Op' of a @Let@
-- statement.  It receives the statement's pattern, its auxiliary
-- information and the operation itself, and returns the statements
-- that replace the original one.
type OnOp rep =
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  Op rep ->
  UnstreamM rep [Stm rep]

-- | Rewrite a sequence of statements with 'optimiseStm'.  The bindings
-- of the whole input sequence are brought into scope first.  Each
-- statement may expand into several, and these are spliced back in the
-- original order.
optimiseStms ::
  ASTRep rep =>
  OnOp rep ->
  Stms rep ->
  UnstreamM rep (Stms rep)
optimiseStms onOp stms =
  localScope (scopeOf stms) $ do
    expanded <- traverse (optimiseStm onOp) (stmsToList stms)
    pure $ stmsFromList $ concat expanded

-- | Rewrite the statements of a body.  The body's decoration and its
-- result are left untouched.
optimiseBody ::
  ASTRep rep =>
  OnOp rep ->
  Body rep ->
  UnstreamM rep (Body rep)
optimiseBody onOp (Body dec stms res) = do
  stms' <- optimiseStms onOp stms
  pure $ Body dec stms' res

-- | Rewrite the statements of a kernel body.  The body's decoration and
-- its kernel results are left untouched.  The statements' own bindings
-- are put in scope by 'optimiseStms'.
optimiseKernelBody ::
  ASTRep rep =>
  OnOp rep ->
  KernelBody rep ->
  UnstreamM rep (KernelBody rep)
optimiseKernelBody onOp (KernelBody dec stms res) = do
  stms' <- optimiseStms onOp stms
  pure $ KernelBody dec stms' res

-- | Rewrite the body of a lambda, with the lambda's parameters in
-- scope.  Every other field of the lambda is preserved.
optimiseLambda ::
  ASTRep rep =>
  OnOp rep ->
  Lambda rep ->
  UnstreamM rep (Lambda rep)
optimiseLambda onOp lam = do
  body' <-
    localScope (scopeOfLParams (lambdaParams lam)) $
      optimiseBody onOp (lambdaBody lam)
  pure lam {lambdaBody = body'}

-- | Rewrite a single statement.
--
-- A statement whose expression is an 'Op' is given to the handler,
-- which decides what it becomes.  Any other statement is kept, but
-- every body nested in its expression is rewritten recursively, with
-- the scope that 'mapExpM' supplies for that body.
optimiseStm ::
  ASTRep rep =>
  OnOp rep ->
  Stm rep ->
  UnstreamM rep [Stm rep]
optimiseStm onOp stm =
  case stm of
    Let pat aux (Op op) ->
      onOp pat aux op
    Let pat aux e -> do
      e' <- mapExpM mapper e
      pure [Let pat aux e']
  where
    mapper =
      identityMapper
        { mapOnBody = \scope body -> localScope scope $ optimiseBody onOp body
        }

-- | Rewrite the kernel bodies and lambdas inside a 'SegOp', with the
-- bindings of its segment space in scope.
optimiseSegOp ::
  ASTRep rep =>
  OnOp rep ->
  SegOp lvl rep ->
  UnstreamM rep (SegOp lvl rep)
optimiseSegOp onOp op =
  localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody onOp,
          mapOnSegOpLambda = optimiseLambda onOp
        }

-- | The 'OnOp' handler for the multicore representation.
--
-- * A 'ParOp' is kept.  Both of its 'SegOp's are rewritten with
--   'optimiseSegOp' (the first one is optional).
--
-- * A SOAC ('MC.OtherOp') that 'sequentialise' selects for this stage
--   is replaced by the statements that 'FOT.transformSOAC' generates.
--   Those statements are rewritten again, so that SOACs nested in the
--   generated code are handled too.  The certificates and attributes of
--   the original statement are attached to the generated statements.
--
-- * Any other SOAC is kept, but the lambdas inside it are rewritten.
onMCOp :: Stage -> OnOp MC
onMCOp stage pat aux (ParOp par_op op) = do
  par_op' <- traverse (optimiseSegOp (onMCOp stage)) par_op
  op' <- optimiseSegOp (onMCOp stage) op
  pure [Let pat aux $ Op $ ParOp par_op' op']
onMCOp stage pat aux (MC.OtherOp soac)
  | sequentialise stage soac = do
      -- 'FOT.transformSOAC' never sees 'aux', so without 'auxing' the
      -- certificates and attributes of the original statement would be
      -- dropped from the generated code.
      stms <- runBuilder_ $ auxing aux $ FOT.transformSOAC pat soac
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm (onMCOp stage)) $
            stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . MC.OtherOp <$> mapSOACM optimise soac)
  where
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onMCOp stage)
        }

-- | Should this SOAC be turned into loops during the given stage?  In
-- the 'SeqStreams' stage only 'Stream's qualify.  In the 'SeqAll' stage
-- every SOAC does.
sequentialise :: Stage -> SOAC rep -> Bool
sequentialise SeqAll _ = True
sequentialise SeqStreams soac =
  case soac of
    Stream {} -> True
    _ -> False

-- | The 'OnOp' handler for the GPU representation.
--
-- * A SOAC ('GPU.OtherOp') that 'sequentialise' selects for this stage
--   is replaced by the statements that 'FOT.transformSOAC' generates.
--   Those statements are rewritten again, so that SOACs nested in the
--   generated code are handled too.  The certificates and attributes of
--   the original statement are attached to the generated statements.
--
-- * Any other SOAC is kept, but the lambdas inside it are rewritten.
--
-- * A 'SegOp' is kept, and its bodies and lambdas are rewritten with
--   'optimiseSegOp'.
--
-- * Every other host operation is returned unchanged.
onHostOp :: Stage -> OnOp GPU
onHostOp stage pat aux (GPU.OtherOp soac)
  | sequentialise stage soac = do
      -- 'FOT.transformSOAC' never sees 'aux', so without 'auxing' the
      -- certificates and attributes of the original statement would be
      -- dropped from the generated code.
      stms <- runBuilder_ $ auxing aux $ FOT.transformSOAC pat soac
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm (onHostOp stage)) $
            stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . GPU.OtherOp <$> mapSOACM optimise soac)
  where
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onHostOp stage)
        }
onHostOp stage pat aux (SegOp op) =
  pure <$> (Let pat aux . Op . SegOp <$> optimiseSegOp (onHostOp stage) op)
onHostOp _ pat aux op = return [Let pat aux $ Op op]