{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module implements common-subexpression elimination.  This
-- module does not actually remove the duplicate, but only replaces
-- one with a reference to the other.  E.g.:
--
-- @
--   let a = x + y
--   let b = x + y
-- @
--
-- becomes:
--
-- @
--   let a = x + y
--   let b = a
-- @
--
-- After which copy propagation in the simplifier will actually remove
-- the definition of @b@.
--
-- Our CSE is still rather simple.  No normalisation is performed, so
-- the expressions @x+y@ and @y+x@ will be considered distinct.
-- Furthermore, no expression that binds its own variables will be
-- considered equal to any other, since the bound variable names will
-- be distinct.  This affects SOACs in particular.
module Futhark.Optimise.CSE
  ( performCSE,
    performCSEOnFunDef,
    performCSEOnStms,
    CSEInOp,
  )
where

import Control.Monad.Reader
import qualified Data.Map.Strict as M
import Futhark.Analysis.Alias
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    mkStmsAliases,
    removeFunDefAliases,
    removeProgAliases,
    removeStmAliases,
  )
import qualified Futhark.IR.Kernels.Kernel as Kernel
import qualified Futhark.IR.MC as MC
import qualified Futhark.IR.Mem as Memory
import Futhark.IR.Prop.Aliases
import qualified Futhark.IR.SOACS.SOAC as SOAC
import Futhark.Pass
import Futhark.Transform.Substitute

-- | The names consumed by the given statements, as reported by
-- 'mkStmsAliases'.  An empty result list is passed because only the
-- consumption component of the analysis is needed here.
consumedInStms :: Aliased lore => Stms lore -> Names
consumedInStms stms = snd $ mkStmsAliases stms []

-- | Perform CSE on every function in a program, and on its top-level
-- constants.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays.  Array CSE should be disabled when the
-- lore has memory information, since at that point arrays have
-- identity beyond their value.
performCSE ::
  ( ASTLore lore,
    CanBeAliased (Op lore),
    CSEInOp (OpWithAliases (Op lore))
  ) =>
  Bool ->
  Pass lore lore
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $
    -- Alias information is computed up front (CSE must not touch
    -- consumed names) and stripped again afterwards.
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts onConsts onFun
      . aliasAnalysis
  where
    -- The top-level constants are handled as a plain statement
    -- sequence with nothing to run afterwards.
    onConsts stms =
      pure $
        fst $
          runReader
            (cseInStms (consumedInStms stms) (stmsToList stms) (return ()))
            (newCSEState cse_arrays)
    -- The (already processed) constants are not needed per function.
    onFun _ = pure . cseInFunDef cse_arrays

-- | Perform CSE on a single function.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays.  Array CSE should be disabled when the
-- lore has memory information, since at that point arrays have
-- identity beyond their value.
performCSEOnFunDef ::
  ( ASTLore lore,
    CanBeAliased (Op lore),
    CSEInOp (OpWithAliases (Op lore))
  ) =>
  Bool ->
  FunDef lore ->
  FunDef lore
performCSEOnFunDef cse_arrays =
  -- Annotate with aliases, CSE, then drop the alias annotations.
  removeFunDefAliases . cseInFunDef cse_arrays . analyseFun

-- | Perform CSE on some statements.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays.  Array CSE should be disabled when the
-- lore has memory information, since at that point arrays have
-- identity beyond their value.
performCSEOnStms ::
  ( ASTLore lore,
    CanBeAliased (Op lore),
    CSEInOp (OpWithAliases (Op lore))
  ) =>
  Bool ->
  Stms lore ->
  Stms lore
performCSEOnStms cse_arrays =
  -- Alias analysis starts from an empty alias table; its summary of
  -- aliases/consumption is not needed, only the annotated statements.
  fmap removeStmAliases . runCSE . fst . analyseStms mempty
  where
    -- CSE the alias-annotated statements in a fresh state.
    runCSE stms =
      fst $
        runReader
          (cseInStms (consumedInStms stms) (stmsToList stms) (return ()))
          (newCSEState cse_arrays)

-- | Perform CSE on the body of a function definition.  Results of
-- non-primitive type are treated as consumed; see the comment on @ds@.
cseInFunDef ::
  (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
  Bool ->
  FunDef lore ->
  FunDef lore
cseInFunDef cse_arrays fundec =
  fundec
    { funDefBody =
        runReader (cseInBody ds $ funDefBody fundec) $ newCSEState cse_arrays
    }
  where
    -- XXX: we treat every non-primitive result as a consumption here,
    -- because our core language is not strong enough to fully capture
    -- the aliases we want, so we are turning some parts off (see #803,
    -- #1241, and the related comment in TypeCheck.hs).  This is not a
    -- practical problem while we still perform such aggressive
    -- inlining.
    ds = map retDiet $ funDefRetType fundec
    retDiet t
      | primType $ declExtTypeOf t = Observe
      | otherwise = Consume

-- | The CSE monad: a reader over the current 'CSEState'.
type CSEM lore = Reader (CSEState lore)

-- | Perform CSE on a body.
--
-- The diets say, per body result, whether that result is consumed.
-- Names aliased by a consumed result, and names consumed by the
-- statements themselves, are passed on as consumed so that statements
-- binding them are never eliminated.  The body result is rewritten
-- with the name substitutions introduced by the statements.
cseInBody ::
  (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
  [Diet] ->
  Body lore ->
  CSEM lore (Body lore)
cseInBody ds (Body bodydec stms res) = do
  (stms', res') <-
    cseInStms (res_cons <> stms_cons) (stmsToList stms) $ do
      CSEState (_, nsubsts) _ <- ask
      return $ substituteNames nsubsts res
  return $ Body bodydec stms' res'
  where
    (res_als, stms_cons) = mkStmsAliases stms res
    res_cons = mconcat $ zipWith consumeResult ds res_als
    -- Only the aliases of consumed results count as consumed.
    consumeResult Consume als = als
    consumeResult _ _ = mempty

-- | Perform CSE on the body of a lambda.  The lambda's results are all
-- treated as observed (never consumed).
cseInLambda ::
  (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
  Lambda lore ->
  CSEM lore (Lambda lore)
cseInLambda lam = do
  let ds = map (const Observe) $ lambdaReturnType lam
  body' <- cseInBody ds $ lambdaBody lam
  return lam {lambdaBody = body'}

-- | Perform CSE on a list of statements, then run the continuation @m@
-- in the environment produced by all of them (so it sees every name
-- substitution they introduced).  Returns the rewritten statements
-- together with the result of @m@.
--
-- Statements binding a name in @consumed@ are never eliminated.  After
-- a statement has been processed, the bodies and operations nested in
-- its expression are traversed as well; a nested body result is
-- treated as consumed when the corresponding value pattern element of
-- the statement is in @consumed@.
cseInStms ::
  (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
  Names ->
  [Stm lore] ->
  CSEM lore a ->
  CSEM lore (Stms lore, a)
cseInStms _ [] m = do
  a <- m
  return (mempty, a)
cseInStms consumed (stm : stms) m =
  cseInStm consumed stm $ \stm' -> do
    (stms', a) <- cseInStms consumed stms m
    stm'' <- mapM nestedCSE stm'
    return (stmsFromList stm'' <> stms', a)
  where
    -- CSE inside the bodies and ops of a statement's expression.
    nestedCSE stm' = do
      let ds = map patElemDiet $ patternValueElements $ stmPattern stm'
      e <- mapExpM (cse ds) $ stmExp stm'
      return stm' {stmExp = e}

    cse ds =
      identityMapper
        { mapOnBody = const $ cseInBody ds,
          mapOnOp = cseInOp
        }

    patElemDiet pe
      | patElemName pe `nameIn` consumed = Consume
      | otherwise = Observe

-- | Perform CSE on a single statement, then run the continuation with
-- the statements that should replace it, in an environment extended
-- with whatever the statement contributed.
--
-- The statement is first rewritten with the current name
-- substitutions.  It is left alone if any of its value pattern
-- elements is a memory block, is an array while array CSE is disabled,
-- or is in @consumed@.  Otherwise, if an identical expression (with an
-- identical expression decoration) was seen before, the statement is
-- replaced by copies of the earlier pattern's names; if not, the
-- expression is recorded for later statements.
cseInStm ::
  ASTLore lore =>
  Names ->
  Stm lore ->
  ([Stm lore] -> CSEM lore a) ->
  CSEM lore a
cseInStm consumed (Let pat (StmAux cs attrs edec) e) m = do
  CSEState (esubsts, nsubsts) cse_arrays <- ask
  let e' = substituteNames nsubsts e
      pat' = substituteNames nsubsts pat
  if any (bad cse_arrays) $ patternValueElements pat
    then m [Let pat' (StmAux cs attrs edec) e']
    else case M.lookup (edec, e') esubsts of
      Just subpat -> do
        -- Certificates are not part of the lookup key, so the earlier
        -- binding need not depend on this statement's certificates.
        -- Redirecting later uses to the earlier names would then
        -- silently drop that dependency.  Only substitute names when
        -- this statement carries no certificates; the copy statements
        -- below always keep @cs@, so later uses stay guarded.
        let substitute
              | cs == mempty = addNameSubst pat' subpat
              | otherwise = id
            lets =
              [ Let (Pattern [] [patElem']) (StmAux cs attrs edec) $
                  BasicOp $ SubExp $ Var $ patElemName patElem
                | (name, patElem) <- zip (patternNames pat') $ patternElements subpat,
                  let patElem' = patElem {patElemName = name}
              ]
        local substitute $ m lets
      Nothing ->
        local (addExpSubst pat' edec e') $
          m [Let pat' (StmAux cs attrs edec) e']
  where
    -- Pattern elements that make a statement ineligible for CSE.
    bad cse_arrays pe
      | Mem {} <- patElemType pe = True
      | Array {} <- patElemType pe, not cse_arrays = True
      | patElemName pe `nameIn` consumed = True
      | otherwise = False

-- | Previously seen expressions, keyed by expression decoration and
-- expression, mapped to the pattern that first bound them.
type ExpressionSubstitutions lore =
  M.Map
    (ExpDec lore, Exp lore)
    (Pattern lore)

-- | Renamings from names bound by eliminated statements to the
-- corresponding names bound by the earlier, equivalent statement.
type NameSubstitutions = M.Map VName VName

-- | The environment threaded through the CSE traversal.
data CSEState lore = CSEState
  { -- | Expressions seen so far, and renamings introduced by
    -- eliminated statements.
    _cseSubstitutions :: (ExpressionSubstitutions lore, NameSubstitutions),
    -- | Whether expressions producing arrays may be eliminated.
    _cseArrays :: Bool
  }

-- | A state with no recorded expressions or renamings.  The argument
-- says whether expressions producing arrays may be eliminated.
newCSEState :: Bool -> CSEState lore
newCSEState = CSEState (M.empty, M.empty)

-- | Map each name bound by the first pattern to the name at the same
-- position in the second pattern.  Extra names in the longer pattern
-- are ignored.
mkSubsts :: PatternT dec -> PatternT dec -> NameSubstitutions
mkSubsts pat vs = M.fromList $ zip (patternNames pat) (patternNames vs)

-- | Record that the names of @pat@ should be replaced by those of
-- @subpat@.  New renamings take precedence over existing ones
-- ('M.union' is left-biased).
addNameSubst :: PatternT dec -> PatternT dec -> CSEState lore -> CSEState lore
addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts, mkSubsts pat subpat `M.union` nsubsts) cse_arrays

-- | Record that expression @e@ with decoration @edec@ is bound by
-- @pat@, so later identical expressions can refer to it.
addExpSubst ::
  ASTLore lore =>
  Pattern lore ->
  ExpDec lore ->
  Exp lore ->
  CSEState lore ->
  CSEState lore
addExpSubst pat edec e (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (M.insert (edec, e) pat esubsts, nsubsts) cse_arrays

-- | Operations (the 'Op' of a lore) inside which CSE can be performed.
class CSEInOp op where
  -- | Run CSE over whatever expressions are nested within an
  -- operation, returning the rewritten operation.
  cseInOp ::
    op ->
    CSEM lore op

instance CSEInOp () where
  -- A unit operation has no nested code.  The '()' pattern forces the
  -- argument before yielding it back.
  cseInOp () = pure ()

-- | Run a CSE computation against a brand-new state, possibly at a
-- different lore than the enclosing one.  None of the enclosing
-- substitutions are visible inside; only the @cse_arrays@ 'Bool' is
-- inherited and passed to 'newCSEState'.
subCSE :: CSEM lore r -> CSEM otherlore r
subCSE m = fresh <$> ask
  where
    fresh (CSEState _ cse_arrays) = runReader m (newCSEState cse_arrays)

instance
  ( ASTLore lore,
    Aliased lore,
    CSEInOp (Op lore),
    CSEInOp op
  ) =>
  CSEInOp (Kernel.HostOp lore op)
  where
  -- Segmented operations and wrapped inner operations are traversed;
  -- every other host operation is returned as-is.
  cseInOp hostop = case hostop of
    Kernel.SegOp segop -> fmap Kernel.SegOp (cseInOp segop)
    Kernel.OtherOp inner -> fmap Kernel.OtherOp (cseInOp inner)
    _ -> pure hostop

instance
  ( ASTLore lore,
    Aliased lore,
    CSEInOp (Op lore),
    CSEInOp op
  ) =>
  CSEInOp (MC.MCOp lore op)
  where
  -- For a parallel operation, both the optional sequential-fallback
  -- segop and the main segop are traversed, in that order.
  cseInOp mcop = case mcop of
    MC.ParOp maybe_par body -> do
      maybe_par' <- traverse cseInOp maybe_par
      body' <- cseInOp body
      pure (MC.ParOp maybe_par' body')
    MC.OtherOp inner ->
      fmap MC.OtherOp (cseInOp inner)

instance
  (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
  CSEInOp (Kernel.SegOp lvl lore)
  where
  -- Run CSE, in a fresh state, over the lambdas and kernel bodies of
  -- the segmented operation.  Sub-expressions, names and the level are
  -- passed through unchanged.
  cseInOp segop =
    subCSE $
      Kernel.mapSegOpM
        (Kernel.SegOpMapper pure cseInLambda cseInKernelBody pure pure)
        segop

-- | Perform CSE on the statements of a kernel body.
--
-- The statements are wrapped in an ordinary 'Body' with an empty result
-- list and handed to 'cseInBody'.  One 'Observe' diet is supplied per
-- kernel result.  Only the rewritten statements are kept; the body
-- decoration and the kernel results are returned unchanged.
cseInKernelBody ::
  (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
  Kernel.KernelBody lore ->
  CSEM lore (Kernel.KernelBody lore)
cseInKernelBody (Kernel.KernelBody dec stms kres) = do
  let diets = Observe <$ kres
  Body _ stms' _ <- cseInBody diets (Body dec stms [])
  pure (Kernel.KernelBody dec stms' kres)

instance CSEInOp op => CSEInOp (Memory.MemOp op) where
  -- Allocations contain no nested code and are left alone.  Inner
  -- operations are processed in a fresh CSE state via 'subCSE'.
  cseInOp memop = case memop of
    Memory.Alloc {} -> pure memop
    Memory.Inner inner -> Memory.Inner <$> subCSE (cseInOp inner)

instance
  ( ASTLore lore,
    CanBeAliased (Op lore),
    CSEInOp (OpWithAliases (Op lore))
  ) =>
  CSEInOp (SOAC.SOAC (Aliases lore))
  where
  -- Run CSE, in a fresh state, over every lambda of the SOAC.
  -- Sub-expressions and names are passed through unchanged.
  cseInOp soac =
    subCSE $
      SOAC.mapSOACM (SOAC.SOACMapper pure cseInLambda pure) soac