{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Sequentialise any remaining SOACs.  It is very important that
-- this is run *after* any access-pattern-related optimisation,
-- because this pass will destroy information.
--
-- This pass conceptually contains three subpasses:
--
-- 1. Sequentialise 'Stream' operations, leaving other SOACs intact.
--
-- 2. Apply whole-program simplification.
--
-- 3. Sequentialise remaining SOACs.
--
-- This is because sequentialisation of streams creates many SOACs
-- operating on single-element arrays, which can be efficiently
-- simplified away, but only *before* they are turned into loops.  In
-- principle this pass could be split into multiple, but for now it is
-- kept together.
module Futhark.Optimise.Unstream (unstream) where

import Control.Monad.Reader
import Control.Monad.State
import Futhark.IR.Kernels
import Futhark.IR.Kernels.Simplify (simplifyKernels)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | Which SOACs a traversal should sequentialise: 'SeqStreams' selects
-- only 'Stream' operations, while 'SeqAll' selects every SOAC (see
-- 'sequentialise').
data Stage = SeqStreams | SeqAll

-- | The pass definition.
--
-- The pass is a composition of three program transformations, run in
-- order: a traversal in the 'SeqStreams' stage, then 'simplifyKernels',
-- then a second traversal in the 'SeqAll' stage.
unstream :: Pass Kernels Kernels
unstream =
  Pass "unstream" "sequentialise remaining SOACs" $
    intraproceduralTransformation (optimise SeqStreams)
      >=> simplifyKernels
      >=> intraproceduralTransformation (optimise SeqAll)
  where
    -- Run the 'UnstreamM' traversal for one stage: the given scope
    -- becomes the reader environment, and the name source of the
    -- surrounding monad is threaded through as the state.
    optimise ::
      MonadFreshNames m =>
      Stage ->
      Scope Kernels ->
      Stms Kernels ->
      m (Stms Kernels)
    optimise stage scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms stage stms) scope

-- | The monad used by the traversal functions in this module: a reader
-- whose environment is the current 'Scope', layered over a state monad
-- holding the 'VNameSource'.
type UnstreamM = ReaderT (Scope Kernels) (State VNameSource)

-- | Apply 'optimiseStm' to every statement of a sequence, in order, and
-- splice the resulting statement lists back into a single sequence.
--
-- The bindings of the whole input sequence are added to the scope while
-- the individual statements are processed.
optimiseStms :: Stage -> Stms Kernels -> UnstreamM (Stms Kernels)
optimiseStms stage stms =
  localScope (scopeOf stms) $
    stmsFromList . concat <$> mapM (optimiseStm stage) (stmsToList stms)

-- | Rewrite the statements of a body with 'optimiseStms'.  The body
-- result is passed through unchanged.
optimiseBody :: Stage -> Body Kernels -> UnstreamM (Body Kernels)
optimiseBody stage (Body () stms res) =
  Body () <$> optimiseStms stage stms <*> pure res

-- | Rewrite the statements of a kernel body.  The kernel results are
-- passed through unchanged.
--
-- The statement traversal is delegated to 'optimiseStms', which brings
-- the statements of the body into scope itself before processing them.
optimiseKernelBody :: Stage -> KernelBody Kernels -> UnstreamM (KernelBody Kernels)
optimiseKernelBody stage (KernelBody () stms res) =
  KernelBody () <$> optimiseStms stage stms <*> pure res

-- | Rewrite the body of a lambda with 'optimiseBody'.  The lambda
-- parameters are in scope while the body is processed; every other
-- field of the lambda is left as it was.
optimiseLambda :: Stage -> Lambda Kernels -> UnstreamM (Lambda Kernels)
optimiseLambda stage lam =
  localScope (scopeOfLParams $ lambdaParams lam) $ do
    body <- optimiseBody stage $ lambdaBody lam
    pure lam {lambdaBody = body}

-- | Decide whether a SOAC is to be sequentialised in the given stage.
--
-- In the 'SeqStreams' stage only 'Stream' operations qualify; in the
-- 'SeqAll' stage every SOAC does.
sequentialise :: Stage -> SOAC Kernels -> Bool
sequentialise SeqStreams Stream {} = True
sequentialise SeqStreams _ = False
sequentialise SeqAll _ = True

-- | Rewrite a single statement, producing the list of statements that
-- replace it.
--
-- * A SOAC selected by 'sequentialise' is expanded with
--   'FOT.transformSOAC', and the statements of that expansion are then
--   processed recursively.
--
-- * Any other SOAC is kept, but its lambdas are rewritten.
--
-- * A 'SegOp' is kept, but its kernel body and lambdas are rewritten
--   with the segment space in scope.
--
-- * Every other expression is kept, with its bodies rewritten.
optimiseStm :: Stage -> Stm Kernels -> UnstreamM [Stm Kernels]
optimiseStm stage (Let pat aux (Op (OtherOp soac)))
  | sequentialise stage soac = do
      stms <- runBinder_ $ FOT.transformSOAC pat soac
      -- Recurse into the expansion so that anything nested in the
      -- generated statements is handled by this stage as well.
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm stage) $
            stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . OtherOp <$> mapSOACM soacMapper soac)
  where
    soacMapper =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda stage
        }
optimiseStm stage (Let pat aux (Op (SegOp op))) =
  localScope (scopeOfSegSpace $ segSpace op) $
    pure <$> (Let pat aux . Op . SegOp <$> mapSegOpM segOpMapper op)
  where
    segOpMapper =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody stage,
          mapOnSegOpLambda = optimiseLambda stage
        }
optimiseStm stage (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM expMapper e)
  where
    expMapper =
      identityMapper
        { mapOnBody = \scope ->
            localScope scope . optimiseBody stage
        }