{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Sequentialise any remaining SOACs.  It is very important that
-- this is run *after* any access-pattern-related optimisation,
-- because this pass will destroy information.
--
-- This pass conceptually contains three subpasses:
--
-- 1. Sequentialise 'Stream' operations, leaving other SOACs intact.
--
-- 2. Apply whole-program simplification.
--
-- 3. Sequentialise remaining SOACs.
--
-- This is because sequentialisation of streams creates many SOACs
-- operating on single-element arrays, which can be efficiently
-- simplified away, but only *before* they are turned into loops.  In
-- principle this pass could be split into three separate passes, but
-- for now it is kept together.
module Futhark.Optimise.Unstream (unstreamKernels, unstreamMC) where

import Control.Monad.Reader
import Control.Monad.State
import Futhark.IR.Kernels
import qualified Futhark.IR.Kernels as Kernels
import Futhark.IR.Kernels.Simplify (simplifyKernels)
import Futhark.IR.MC
import qualified Futhark.IR.MC as MC
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | Unstreaming for the GPU ('Kernels') representation.  'Op'
-- statements are handled by 'onHostOp', and 'simplifyKernels' is the
-- simplifier run between the two sequentialisation stages.
unstreamKernels :: Pass Kernels Kernels
unstreamKernels = unstream handler simplifier
  where
    handler = onHostOp
    simplifier = simplifyKernels

-- | Unstreaming for the multicore ('MC') representation.  'Op'
-- statements are handled by 'onMCOp', and 'MC.simplifyProg' is the
-- simplifier run between the two sequentialisation stages.
unstreamMC :: Pass MC MC
unstreamMC = unstream handler simplifier
  where
    handler = onMCOp
    simplifier = MC.simplifyProg

-- | Which SOACs a traversal sequentialises (see 'sequentialise').
data Stage
  = -- | Only 'Stream' SOACs are sequentialised.
    SeqStreams
  | -- | Every remaining SOAC is sequentialised.
    SeqAll

-- | Build an unstreaming pass from two representation-specific pieces:
-- a handler for 'Op' statements (parameterised by the current 'Stage')
-- and a whole-program simplifier.  The resulting pass
--
-- 1. traverses the program with 'SeqStreams',
--
-- 2. runs the simplifier,
--
-- 3. traverses the program again with 'SeqAll'.
unstream ::
  ASTLore lore =>
  (Stage -> OnOp lore) ->
  (Prog lore -> PassM (Prog lore)) ->
  Pass lore lore
unstream onOp simplify =
  Pass "unstream" "sequentialise remaining SOACs" pipeline
  where
    pipeline =
      runStage SeqStreams
        >=> simplify
        >=> runStage SeqAll

    -- One traversal via 'intraproceduralTransformation': the statements
    -- are rewritten in 'UnstreamM', starting from the supplied scope and
    -- threading the pass's name source through the 'State' layer.
    runStage stage = intraproceduralTransformation $ \scope stms ->
      modifyNameSource $ \src ->
        runState (runReaderT (optimiseStms (onOp stage) stms) scope) src

-- | The traversal monad: the 'Scope' at the current position is read
-- from the environment, and the name source is kept as state.
type UnstreamM lore = ReaderT (Scope lore) (State VNameSource)

-- | A handler for an 'Op' statement.  It receives the statement's
-- pattern, its auxiliary information and the operation itself, and
-- returns the statements that replace it.
type OnOp lore =
  Pattern lore ->
  StmAux (ExpDec lore) ->
  Op lore ->
  UnstreamM lore [Stm lore]

-- | Rewrite a sequence of statements with their own bindings in scope.
-- Every statement may expand into several; the results are
-- concatenated in order.
optimiseStms ::
  ASTLore lore =>
  OnOp lore ->
  Stms lore ->
  UnstreamM lore (Stms lore)
optimiseStms onOp stms = localScope (scopeOf stms) $ do
  rewritten <- mapM (optimiseStm onOp) (stmsToList stms)
  pure $ stmsFromList $ concat rewritten

-- | Rewrite the statements of a body.  The body's decoration and result
-- are kept unchanged.
optimiseBody ::
  ASTLore lore =>
  OnOp lore ->
  Body lore ->
  UnstreamM lore (Body lore)
optimiseBody onOp (Body dec stms res) = do
  stms' <- optimiseStms onOp stms
  pure $ Body dec stms' res

-- | Rewrite the statements of a 'KernelBody'.  The decoration and the
-- kernel results are kept unchanged.
--
-- This delegates to 'optimiseStms' rather than repeating its logic;
-- 'optimiseStms' already brings the statements' bindings into scope.
optimiseKernelBody ::
  ASTLore lore =>
  OnOp lore ->
  KernelBody lore ->
  UnstreamM lore (KernelBody lore)
optimiseKernelBody onOp (KernelBody dec stms res) =
  KernelBody dec <$> optimiseStms onOp stms <*> pure res

-- | Rewrite the body of a lambda with the lambda's parameters in scope.
-- Everything else about the lambda is kept unchanged.
optimiseLambda ::
  ASTLore lore =>
  OnOp lore ->
  Lambda lore ->
  UnstreamM lore (Lambda lore)
optimiseLambda onOp lam = do
  body' <-
    localScope (scopeOfLParams (lambdaParams lam)) $
      optimiseBody onOp (lambdaBody lam)
  pure lam {lambdaBody = body'}

-- | Rewrite a single statement.
--
-- 'Op' statements are handed to the supplied handler, which may
-- replace them with any number of statements.  Any other expression is
-- kept, but every body nested inside it is rewritten recursively (with
-- the scope supplied by 'mapExpM').
optimiseStm ::
  ASTLore lore =>
  OnOp lore ->
  Stm lore ->
  UnstreamM lore [Stm lore]
optimiseStm onOp (Let pat aux e) =
  case e of
    Op op -> onOp pat aux op
    _ -> do
      e' <- mapExpM nestedBodies e
      pure [Let pat aux e']
  where
    nestedBodies =
      identityMapper
        { mapOnBody = \scope body -> localScope scope (optimiseBody onOp body)
        }

-- | Rewrite the kernel bodies and lambdas of a 'SegOp', with the
-- segment space's bindings in scope.
optimiseSegOp ::
  ASTLore lore =>
  OnOp lore ->
  SegOp lvl lore ->
  UnstreamM lore (SegOp lvl lore)
optimiseSegOp onOp op = localScope segScope $ mapSegOpM mapper op
  where
    segScope = scopeOfSegSpace (segSpace op)
    mapper =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody onOp,
          mapOnSegOpLambda = optimiseLambda onOp
        }

-- | The 'OnOp' handler for the multicore representation.
--
-- * A 'ParOp' is rebuilt after traversing its optional first 'SegOp'
--   and its second 'SegOp'.
--
-- * An 'MC.OtherOp' SOAC selected by 'sequentialise' for this stage is
--   replaced by the statements 'FOT.transformSOAC' generates for it.
--   Those statements are traversed again, so SOACs they contain are
--   handled too.  The original statement's certificates are attached to
--   the generated statements.
--
-- * Any other SOAC is kept, but its lambdas are still traversed.
onMCOp :: Stage -> OnOp MC
onMCOp stage pat aux (ParOp par_op op) = do
  par_op' <- traverse (optimiseSegOp (onMCOp stage)) par_op
  op' <- optimiseSegOp (onMCOp stage) op
  pure [Let pat aux $ Op $ ParOp par_op' op']
onMCOp stage pat aux (MC.OtherOp soac)
  | sequentialise stage soac = do
      -- Previously 'aux' was ignored here, silently dropping the
      -- statement's certificates.  Re-attach them so the generated code
      -- keeps the dependencies the SOAC had.
      -- NOTE(review): attributes ('stmAuxAttrs') are still not carried over.
      stms <-
        runBinder_ $
          certifying (stmAuxCerts aux) $
            FOT.transformSOAC pat soac
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm (onMCOp stage)) $ stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . MC.OtherOp <$> mapSOACM optimise soac)
  where
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onMCOp stage)
        }

-- | Whether a SOAC should be turned into sequential code at the given
-- stage.  'SeqAll' accepts every SOAC; 'SeqStreams' accepts only
-- 'Stream'.
sequentialise :: Stage -> SOAC lore -> Bool
sequentialise SeqAll _ = True
sequentialise SeqStreams soac =
  case soac of
    Stream {} -> True
    _ -> False

-- | The 'OnOp' handler for the GPU ('Kernels') representation.
--
-- * A 'Kernels.OtherOp' SOAC selected by 'sequentialise' for this stage
--   is replaced by the statements 'FOT.transformSOAC' generates for it.
--   Those statements are traversed again, so SOACs they contain are
--   handled too.  The original statement's certificates are attached to
--   the generated statements.
--
-- * Any other SOAC is kept, but its lambdas are still traversed.
--
-- * A 'SegOp' is kept, but its bodies and lambdas are traversed.
--
-- * Any other host operation is returned unchanged.
onHostOp :: Stage -> OnOp Kernels
onHostOp stage pat aux (Kernels.OtherOp soac)
  | sequentialise stage soac = do
      -- Previously 'aux' was ignored here, silently dropping the
      -- statement's certificates.  Re-attach them so the generated code
      -- keeps the dependencies the SOAC had.
      -- NOTE(review): attributes ('stmAuxAttrs') are still not carried over.
      stms <-
        runBinder_ $
          certifying (stmAuxCerts aux) $
            FOT.transformSOAC pat soac
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm (onHostOp stage)) $ stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . Kernels.OtherOp <$> mapSOACM optimise soac)
  where
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onHostOp stage)
        }
onHostOp stage pat aux (SegOp op) =
  pure <$> (Let pat aux . Op . SegOp <$> optimiseSegOp (onHostOp stage) op)
onHostOp _ pat aux op = return [Let pat aux $ Op op]