{-# LANGUAGE UndecidableInstances #-}

-- | This module implements common-subexpression elimination.  This
-- module does not actually remove the duplicate, but only replaces
-- one with a reference to the other.  E.g.:
--
-- @
--   let a = x + y
--   let b = x + y
-- @
--
-- becomes:
--
-- @
--   let a = x + y
--   let b = a
-- @
--
-- After which copy propagation in the simplifier will actually remove
-- the definition of @b@.
--
-- Our CSE is still rather stupid.  No normalisation is performed, so
-- the expressions @x+y@ and @y+x@ will be considered distinct.
-- Furthermore, no expression with its own binding will be considered
-- equal to any other, since the variable names will be distinct.
-- This affects SOACs in particular.
module Futhark.Optimise.CSE
  ( performCSE,
    performCSEOnFunDef,
    performCSEOnStms,
    CSEInOp,
  )
where

import Control.Monad.Reader
import Data.Map.Strict qualified as M
import Futhark.Analysis.Alias
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    consumedInStms,
    mkStmsAliases,
    removeFunDefAliases,
    removeProgAliases,
    removeStmAliases,
  )
import Futhark.IR.GPU qualified as GPU
import Futhark.IR.MC qualified as MC
import Futhark.IR.Mem qualified as Memory
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS.SOAC qualified as SOAC
import Futhark.Pass
import Futhark.Transform.Substitute

-- | Perform CSE on every function in a program.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays. This should be disabled when the rep has
-- memory information, since at that point arrays have identity beyond their
-- value.
--
-- The program is first annotated with alias information, CSE is then run
-- on the top-level constants and on every function, and the alias
-- annotations are stripped again at the end.
performCSE ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  Pass rep rep
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $ \prog ->
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts
        (onConsts (freeIn (progFuns prog)))
        onFun
      . aliasAnalysis
      $ prog
  where
    -- Constants referenced by any function, everything aliased by them,
    -- and everything consumed by the constant statements themselves are
    -- treated as consumed.  Consumed names are never subject to CSE
    -- (conservatively, as the function bodies are not inspected here).
    onConsts free_in_funs stms =
      let (res_als, stms_cons) =
            mkStmsAliases stms $ varsRes $ namesToList free_in_funs
          consumed = mconcat res_als <> stms_cons
       in pure . fst $
            runReader
              (cseInStms consumed (stmsToList stms) (pure ()))
              (newCSEState cse_arrays)
    onFun _ = pure . cseInFunDef cse_arrays

-- | Perform CSE on a single function.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays. This should be disabled when the rep has
-- memory information, since at that point arrays have identity beyond their
-- value.
--
-- Alias information is computed for the function, used during CSE, and
-- removed again before the function is returned.
performCSEOnFunDef ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  FunDef rep ->
  FunDef rep
performCSEOnFunDef cse_arrays =
  removeFunDefAliases . cseInFunDef cse_arrays . analyseFun

-- | Perform CSE on some statements.
--
-- The statements are alias-analysed starting from an empty alias table,
-- and CSE runs with array CSE disabled.  This matches the effect of
-- passing 'False' to 'performCSEOnFunDef'.
performCSEOnStms ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Stms rep ->
  Stms rep
performCSEOnStms =
  fmap removeStmAliases . cseStms . fst . analyseStms mempty
  where
    -- Names consumed anywhere in the statements are exempt from CSE.
    cseStms stms =
      fst $
        runReader
          (cseInStms (consumedInStms stms) (stmsToList stms) (pure ()))
          -- It is never safe to CSE arrays in stms in isolation,
          -- because we might introduce additional aliasing.
          (newCSEState False)

-- | Perform CSE on the body of an alias-annotated function.  The boolean
-- says whether array-producing expressions may be merged.  Results of
-- non-primitive return type are treated as consumed (see below).
cseInFunDef ::
  (Aliased rep, CSEInOp (Op rep)) =>
  Bool ->
  FunDef rep ->
  FunDef rep
cseInFunDef cse_arrays fundec =
  fundec
    { funDefBody =
        runReader (cseInBody ds $ funDefBody fundec) $ newCSEState cse_arrays
    }
  where
    -- XXX: we treat every array result as a consumption here, because
    -- it is otherwise complicated to ensure we do not introduce more
    -- aliasing than specified by the return type. This is not a
    -- practical problem while we still perform such aggressive
    -- inlining.
    ds = map (retDiet . fst) $ funDefRetType fundec
    retDiet t
      | primType $ declExtTypeOf t = Observe
      | otherwise = Consume

-- | The CSE monad: read-only access to the current 'CSEState'.
type CSEM rep =
  Reader (CSEState rep)

-- | Perform CSE on a body.
--
-- The diets say how the context uses each element of the body's result.
-- Names aliased by a result with a 'Consume' diet, and names consumed by
-- the statements themselves, are treated as consumed and are not
-- subject to CSE.  The result is rewritten with the name substitutions
-- accumulated while processing the statements.
cseInBody ::
  (Aliased rep, CSEInOp (Op rep)) =>
  [Diet] ->
  Body rep ->
  CSEM rep (Body rep)
cseInBody ds (Body bodydec stms res) = do
  (stms', res') <-
    cseInStms (res_cons <> stms_cons) (stmsToList stms) $ do
      CSEState (_, nsubsts) _ <- ask
      pure $ substituteNames nsubsts res
  pure $ Body bodydec stms' res'
  where
    (res_als, stms_cons) = mkStmsAliases stms res
    -- Only the aliases of results that are consumed count as consumed.
    res_cons = mconcat $ zipWith consumeResult ds res_als
    consumeResult Consume als = als
    consumeResult _ _ = mempty

-- | Perform CSE on the body of a lambda.  Every result of a lambda is
-- treated as merely observed.
cseInLambda ::
  (Aliased rep, CSEInOp (Op rep)) =>
  Lambda rep ->
  CSEM rep (Lambda rep)
cseInLambda lam = do
  body' <-
    cseInBody (map (const Observe) $ lambdaReturnType lam) $ lambdaBody lam
  pure lam {lambdaBody = body'}

-- | Perform CSE on a sequence of statements, then run the continuation
-- @m@ in the environment extended with every expression and renaming
-- discovered along the way.
--
-- The names in the first argument are considered consumed and are not
-- subject to CSE.  Bodies and operations nested inside a statement are
-- processed in the environment established by that statement and the
-- ones preceding it.
cseInStms ::
  forall rep a.
  (Aliased rep, CSEInOp (Op rep)) =>
  Names ->
  [Stm rep] ->
  CSEM rep a ->
  CSEM rep (Stms rep, a)
cseInStms _ [] m = do
  a <- m
  pure (mempty, a)
cseInStms consumed (stm : stms) m =
  cseInStm consumed stm $ \stm' -> do
    (stms', a) <- cseInStms consumed stms m
    stm'' <- mapM nestedCSE stm'
    pure (stmsFromList stm'' <> stms', a)
  where
    -- Recurse into the bodies and operations nested in the expression.
    -- For a loop, the diets of the nested body come from the declared
    -- types of the loop parameters; otherwise a result is consumed iff
    -- the corresponding pattern element is consumed.
    nestedCSE stm' = do
      let ds =
            case stmExp stm' of
              Loop merge _ _ -> map (diet . declTypeOf . fst) merge
              _ -> map patElemDiet $ patElems $ stmPat stm'
      e <- mapExpM (cse ds) $ stmExp stm'
      pure stm' {stmExp = e}

    cse ds =
      (identityMapper @rep)
        { mapOnBody = const $ cseInBody ds,
          mapOnOp = cseInOp
        }

    patElemDiet pe
      | patElemName pe `nameIn` consumed = Consume
      | otherwise = Observe

-- A small amount of normalisation of expressions that otherwise would
-- be different for pointless reasons.  Currently this only clears the
-- source locations attached to function applications, so that calls
-- differing only in those locations compare equal.
normExp :: Exp lore -> Exp lore
normExp (Apply fname args ret (safety, _, _)) =
  Apply fname args ret (safety, mempty, mempty)
normExp e = e

-- | Perform CSE on a single statement and pass the resulting statements
-- to the continuation.
--
-- The expression and pattern are first renamed with the current name
-- substitutions, and the expression is normalised.  Then:
--
-- * If any pattern element binds memory, binds an array while array CSE
--   is disabled, or binds a consumed name, the statement is passed on
--   unchanged apart from the renaming.  This check is skipped for
--   'Index' and 'Reshape', whose results already alias their input.
--
-- * If an identical expression with the same expression decoration was
--   bound earlier, the statement is replaced by one copy statement per
--   pattern element referring to the earlier pattern.  When the earlier
--   statement's certificates include all of this statement's, the names
--   are also substituted in subsequent code.
--
-- * Otherwise the expression is recorded for use by later statements.
cseInStm ::
  (ASTRep rep) =>
  Names ->
  Stm rep ->
  ([Stm rep] -> CSEM rep a) ->
  CSEM rep a
cseInStm consumed (Let pat (StmAux cs attrs edec) e) m = do
  CSEState (esubsts, nsubsts) cse_arrays <- ask
  let e' = normExp $ substituteNames nsubsts e
      pat' = substituteNames nsubsts pat
  if not (alreadyAliases e) && any (bad cse_arrays) (patElems pat)
    then m [Let pat' (StmAux cs attrs edec) e']
    else case M.lookup (edec, e') esubsts of
      Just (subcs, subpat) -> do
        let subsumes = all (`elem` unCerts subcs) (unCerts cs)
        -- We can only do a plain name substitution if it doesn't
        -- violate any certificate dependencies.
        local (if subsumes then addNameSubst pat' subpat else id) $ do
          let lets =
                [ Let (Pat [patElem']) (StmAux cs attrs edec) $
                    BasicOp (SubExp $ Var $ patElemName patElem)
                | (name, patElem) <- zip (patNames pat') $ patElems subpat,
                  let patElem' = patElem {patElemName = name}
                ]
          m lets
      _ ->
        local (addExpSubst pat' edec cs e') $
          m [Let pat' (StmAux cs attrs edec) e']
  where
    alreadyAliases (BasicOp Index {}) = True
    alreadyAliases (BasicOp Reshape {}) = True
    alreadyAliases _ = False
    -- A pattern element that must not take part in CSE.
    bad cse_arrays pe
      | Mem {} <- patElemType pe = True
      | Array {} <- patElemType pe, not cse_arrays = True
      | patElemName pe `nameIn` consumed = True
      | otherwise = False

-- | Maps an expression, together with its expression decoration, to the
-- certificates and pattern of the statement that first bound it.
type ExpressionSubstitutions rep =
  M.Map
    (ExpDec rep, Exp rep)
    (Certs, Pat (LetDec rep))

-- | Renamings applied to subsequent code: each key is replaced by the
-- corresponding value.
type NameSubstitutions = M.Map VName VName

-- | The environment of the CSE monad.
data CSEState rep = CSEState
  { -- | Expressions seen so far, and the renamings to apply.
    _cseSubstitutions :: (ExpressionSubstitutions rep, NameSubstitutions),
    -- | Whether array-producing expressions may be merged.
    _cseArrays :: Bool
  }

-- | An initial CSE environment with no recorded expressions or renamings.
-- The argument says whether array-producing expressions may be merged.
newCSEState :: Bool -> CSEState rep
newCSEState = CSEState (M.empty, M.empty)

-- | Pair up the names of two patterns position by position, mapping each
-- name of the first pattern to the corresponding name of the second.
-- Surplus names in the longer pattern are ignored.
mkSubsts :: Pat dec -> Pat dec -> M.Map VName VName
mkSubsts pat vs = M.fromList $ zip (patNames pat) (patNames vs)

-- | Record that the names bound by the first pattern are to be replaced
-- by the corresponding names of the second pattern.  Because 'M.union'
-- is left-biased, the new renamings take precedence over existing ones.
addNameSubst :: Pat dec -> Pat dec -> CSEState rep -> CSEState rep
addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts, mkSubsts pat subpat `M.union` nsubsts) cse_arrays

-- | Record that the given expression, with the given expression
-- decoration, has been bound to the given pattern under the given
-- certificates.  Later identical expressions can then reuse it.
addExpSubst ::
  (ASTRep rep) =>
  Pat (LetDec rep) ->
  ExpDec rep ->
  Certs ->
  Exp rep ->
  CSEState rep ->
  CSEState rep
addExpSubst pat edec cs e (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (M.insert (edec, e) (cs, pat) esubsts, nsubsts) cse_arrays

-- | The operations that permit CSE.
class CSEInOp op where
  -- | Perform CSE within any nested expressions.
  cseInOp :: op -> CSEM rep op

-- | 'NoOp' contains nothing to process.
instance CSEInOp (NoOp rep) where
  cseInOp NoOp = pure NoOp

-- | Run a CSE computation, possibly for a different representation, in a
-- fresh environment.  No expressions or renamings from the enclosing
-- context are visible to it, but the array-CSE setting is inherited.
subCSE :: CSEM rep r -> CSEM otherrep r
subCSE m = do
  CSEState _ cse_arrays <- ask
  pure $ runReader m $ newCSEState cse_arrays

-- | CSE inside GPU host operations.  Segmented operations and the inner
-- operation are processed recursively.  The body of a 'GPU.GPUBody' is
-- processed in a fresh environment, with all of its results observed.
-- Any other host operation is returned unchanged.
instance
  ( Aliased rep,
    CSEInOp (Op rep),
    CSEInOp (op rep)
  ) =>
  CSEInOp (GPU.HostOp op rep)
  where
  cseInOp (GPU.SegOp op) = GPU.SegOp <$> cseInOp op
  cseInOp (GPU.OtherOp op) = GPU.OtherOp <$> cseInOp op
  cseInOp (GPU.GPUBody types body) =
    subCSE $ GPU.GPUBody types <$> cseInBody (map (const Observe) types) body
  cseInOp x = pure x

-- | CSE inside multicore operations.  For a 'MC.ParOp', both the optional
-- and the mandatory segmented operation are processed.  For
-- 'MC.OtherOp', the inner operation is processed.
instance
  ( Aliased rep,
    CSEInOp (Op rep),
    CSEInOp (op rep)
  ) =>
  CSEInOp (MC.MCOp op rep)
  where
  cseInOp (MC.ParOp par_op op) =
    MC.ParOp <$> traverse cseInOp par_op <*> cseInOp op
  cseInOp (MC.OtherOp op) =
    MC.OtherOp <$> cseInOp op

-- | CSE inside a segmented operation, in a fresh environment.  Lambdas
-- and kernel bodies are processed.  Sub-expressions, variable names and
-- the level are left unchanged.
instance
  (Aliased rep, CSEInOp (Op rep)) =>
  CSEInOp (GPU.SegOp lvl rep)
  where
  cseInOp =
    subCSE
      . GPU.mapSegOpM
        (GPU.SegOpMapper pure cseInLambda cseInKernelBody pure pure)

-- | Perform CSE on the statements of a kernel body.  The statements are
-- wrapped in a 'Body' with an empty result, so no result is treated as
-- consumed.  The kernel results themselves are returned as they were.
cseInKernelBody ::
  (Aliased rep, CSEInOp (Op rep)) =>
  GPU.KernelBody rep ->
  CSEM rep (GPU.KernelBody rep)
cseInKernelBody (GPU.KernelBody bodydec stms res) = do
  Body _ stms' _ <- cseInBody (map (const Observe) res) $ Body bodydec stms []
  pure $ GPU.KernelBody bodydec stms' res

-- | Allocations are left as they are.  Inner operations are processed in
-- a fresh environment.
instance (CSEInOp (op rep)) => CSEInOp (Memory.MemOp op rep) where
  cseInOp o@Memory.Alloc {} = pure o
  cseInOp (Memory.Inner k) = Memory.Inner <$> subCSE (cseInOp k)

-- | CSE inside the lambdas of a SOAC, in a fresh environment.
-- Sub-expressions and variable names are left unchanged.
instance
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  CSEInOp (SOAC.SOAC (Aliases rep))
  where
  cseInOp = subCSE . SOAC.mapSOACM (SOAC.SOACMapper pure cseInLambda pure)