{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
-- | This module implements common-subexpression elimination.  This
-- module does not actually remove the duplicate, but only replaces
-- one with a reference to the other.  E.g.:
--
-- @
--   let a = x + y
--   let b = x + y
-- @
--
-- becomes:
--
-- @
--   let a = x + y
--   let b = a
-- @
--
-- After which copy propagation in the simplifier will actually remove
-- the definition of @b@.
--
-- Our CSE is still rather stupid.  No normalisation is performed, so
-- the expressions @x+y@ and @y+x@ will be considered distinct.
-- Furthermore, no expression with its own binding will be considered
-- equal to any other, since the variable names will be distinct.
-- This affects SOACs in particular.
module Futhark.Optimise.CSE
       ( performCSE
       , performCSEOnFunDef
       , performCSEOnStms
       , CSEInOp
       )
       where

import Control.Monad.Reader
import qualified Data.Map.Strict as M

import Futhark.Analysis.Alias
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.IR.Aliases
  (removeProgAliases, removeFunDefAliases, removeStmAliases,
   Aliases, consumedInStms)
import qualified Futhark.IR.Kernels.Kernel as Kernel
import qualified Futhark.IR.SOACS.SOAC as SOAC
import qualified Futhark.IR.Mem as Memory
import Futhark.Transform.Substitute
import Futhark.Pass

-- | Perform CSE on every function in a program.
--
-- The program is alias-annotated before CSE (so that consumed variables
-- can be recognised) and the annotations are stripped again afterwards.
-- CSE is applied both to the top-level constants and to every function
-- body.  The 'Bool' decides whether array-typed bindings may be
-- deduplicated.
performCSE :: (ASTLore lore, CanBeAliased (Op lore),
               CSEInOp (OpWithAliases (Op lore))) =>
              Bool -> Pass lore lore
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $ \prog ->
    removeProgAliases <$>
      intraproceduralTransformationWithConsts cseConsts cseFun (aliasAnalysis prog)
  where
    -- The constants are a statement sequence with nothing following it,
    -- so the continuation is trivial.
    cseConsts stms =
      pure $ fst $
        runReader (cseInStms (consumedInStms stms) (stmsToList stms) (pure ()))
                  (newCSEState cse_arrays)

    -- Functions ignore the constants they are handed.
    cseFun _consts fundef = pure $ cseInFunDef cse_arrays fundef

-- | Perform CSE on a single function.
--
-- The function is alias-annotated first, so that CSE can leave consumed
-- variables alone, and the annotations are removed from the result.
-- The 'Bool' decides whether array-typed bindings may be deduplicated.
performCSEOnFunDef :: (ASTLore lore, CanBeAliased (Op lore),
                       CSEInOp (OpWithAliases (Op lore))) =>
                      Bool -> FunDef lore -> FunDef lore
performCSEOnFunDef cse_arrays fundef =
  let annotated = analyseFun fundef
      simplified = cseInFunDef cse_arrays annotated
  in removeFunDefAliases simplified

-- | Perform CSE on some statements.
--
-- Alias analysis starts from an empty alias table, and the alias
-- annotations are removed again from the result.  The 'Bool' decides
-- whether array-typed bindings may be deduplicated.
performCSEOnStms :: (ASTLore lore, CanBeAliased (Op lore),
                     CSEInOp (OpWithAliases (Op lore))) =>
                    Bool -> Stms lore -> Stms lore
performCSEOnStms cse_arrays stms =
  removeStmAliases <$> cseStms aliased_stms
  where
    (aliased_stms, _) = analyseStms mempty stms

    -- Run CSE over the statements with a trivial continuation.
    cseStms xs =
      fst $ runReader (cseInStms (consumedInStms xs) (stmsToList xs) (pure ()))
                      (newCSEState cse_arrays)

-- | Run CSE over the body of a function definition, starting from an
-- empty environment.  The diet of each declared return type decides
-- whether the corresponding result counts as consumed.
cseInFunDef :: (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
               Bool -> FunDef lore -> FunDef lore
cseInFunDef cse_arrays fundec =
  fundec { funDefBody = body' }
  where
    diets = [ diet (declExtTypeOf rt) | rt <- funDefRetType fundec ]
    body' = runReader (cseInBody diets (funDefBody fundec))
                      (newCSEState cse_arrays)

-- | The CSE monad: a read-only environment ('CSEState') holding the
-- substitutions discovered so far and the array-CSE flag.
type CSEM lore = Reader (CSEState lore)

-- | Perform CSE on a body.
--
-- The diets are matched positionally with the body's results.  Variables
-- free in a result marked 'Consume' are treated as consumed, in addition
-- to whatever the statements themselves consume.  The result is
-- rewritten with the name substitutions in effect after all statements.
cseInBody :: (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
             [Diet] -> Body lore -> CSEM lore (Body lore)
cseInBody ds (Body bodydec bnds res) = do
  (bnds', res') <-
    cseInStms (res_consumed <> consumedInStms bnds) (stmsToList bnds) $
      asks $ \(CSEState (_, nsubsts) _) -> substituteNames nsubsts res
  pure $ Body bodydec bnds' res'
  where
    -- Variables returned in a consumed position.
    res_consumed = mconcat [ freeIn se | (Consume, se) <- zip ds res ]

-- | Perform CSE on the body of a lambda.  None of the lambda's results
-- are treated as consumed.
cseInLambda :: (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
               Lambda lore -> CSEM lore (Lambda lore)
cseInLambda lam =
  (\body' -> lam { lambdaBody = body' }) <$> cseInBody observed (lambdaBody lam)
  where observed = Observe <$ lambdaReturnType lam

-- | Perform CSE on a list of statements, then run the continuation @m@
-- in the environment extended by all of them.
--
-- The 'Names' are the consumed variables; statements binding one of
-- them never take part in CSE.  Each statement produced by 'cseInStm'
-- additionally has CSE applied to the bodies and ops nested inside its
-- expression.
cseInStms :: (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
             Names -> [Stm lore]
          -> CSEM lore a
          -> CSEM lore (Stms lore, a)
cseInStms consumed = go
  where
    go [] m = (,) mempty <$> m
    go (bnd:bnds) m =
      cseInStm consumed bnd $ \new_bnds -> do
        (rest, a) <- go bnds m
        new_bnds' <- traverse nestedCSE new_bnds
        pure (stmsFromList new_bnds' <> rest, a)

    -- CSE everything nested inside a statement's expression.  Nested
    -- bodies get one diet per value element of the statement's pattern.
    nestedCSE stm = do
      let ds = map patElemDiet $ patternValueElements $ stmPattern stm
      e <- mapExpM (nestedMapper ds) $ stmExp stm
      pure stm { stmExp = e }

    -- Recurse into nested bodies (with the given result diets) and ops.
    nestedMapper ds =
      identityMapper { mapOnBody = \_ body -> cseInBody ds body
                     , mapOnOp = cseInOp
                     }

    patElemDiet pe =
      if patElemName pe `nameIn` consumed then Consume else Observe

-- | Perform CSE on a single statement and pass the resulting statements
-- to the continuation @m@.
--
-- The expression and pattern are first rewritten with the current name
-- substitutions.  If a value pattern element is a memory block, is an
-- array while array CSE is disabled, or is consumed, the statement is
-- passed on as-is and is not remembered.  Otherwise, if the same
-- expression with the same decoration is already known, each bound name
-- is instead defined as a copy of the corresponding earlier name.  In
-- that case @m@ runs with those names renamed to the earlier ones.  If
-- the expression is not yet known, it is remembered for the statements
-- handled by @m@.
cseInStm :: ASTLore lore =>
            Names -> Stm lore
         -> ([Stm lore] -> CSEM lore a)
         -> CSEM lore a
cseInStm consumed (Let pat (StmAux cs attrs edec) e) m = do
  CSEState (esubsts, nsubsts) cse_arrays <- ask
  let e' = substituteNames nsubsts e
      pat' = substituteNames nsubsts pat
      aux = StmAux cs attrs edec
      kept = Let pat' aux e'
  if any (unsafeToCSE cse_arrays) (patternValueElements pat)
    then m [kept]
    else case M.lookup (edec, e') esubsts of
      Nothing ->
        local (addExpSubst pat' edec e') $ m [kept]
      Just subpat ->
        local (addNameSubst pat' subpat) $ m
          [ Let (Pattern [] [pe { patElemName = name }]) aux $
              BasicOp $ SubExp $ Var $ patElemName pe
          | (name, pe) <- zip (patternNames pat') (patternElements subpat) ]
  where
    -- Bindings that may neither be replaced by a copy nor serve as the
    -- source of one.
    unsafeToCSE arrays_ok pe =
      case patElemType pe of
        Mem{} -> True
        Array{} | not arrays_ok -> True
        _ -> patElemName pe `nameIn` consumed

-- | Expressions that have already been bound.  Each is keyed by its
-- expression decoration together with the (name-substituted)
-- expression, and maps to the pattern that binds it.
type ExpressionSubstitutions lore = M.Map
                                    (ExpDec lore, Exp lore)
                                    (Pattern lore)
-- | Renamings from the names bound by a duplicate statement to the
-- names bound by the earlier, equivalent statement.
type NameSubstitutions = M.Map VName VName

-- | The environment of the CSE monad.
data CSEState lore = CSEState
  { _cseSubstitutions :: (ExpressionSubstitutions lore, NameSubstitutions)
    -- ^ Known expressions, and renamings of duplicate bindings.
  , _cseArrays :: Bool
    -- ^ Whether array-typed bindings may take part in CSE.
  }

-- | An environment with no known expressions and no renamings.  The
-- argument says whether arrays may take part in CSE.
newCSEState :: Bool -> CSEState lore
newCSEState cse_arrays =
  CSEState { _cseSubstitutions = (M.empty, M.empty)
           , _cseArrays = cse_arrays
           }

-- | Map every name bound by the first pattern to the name at the same
-- position in the second pattern.
mkSubsts :: PatternT dec -> PatternT dec -> M.Map VName VName
mkSubsts pat vs = M.fromList (zip from to)
  where from = patternNames pat
        to = patternNames vs

-- | Record that the names bound by @pat@ are to be replaced by the
-- corresponding names of @subpat@.  New renamings take precedence over
-- existing ones with the same key.
addNameSubst :: PatternT dec -> PatternT dec -> CSEState lore -> CSEState lore
addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts, nsubsts') cse_arrays
  where nsubsts' = M.union (mkSubsts pat subpat) nsubsts

-- | Remember that expression @e@ with decoration @edec@ is bound by
-- @pat@, so that later identical expressions can reuse it.
addExpSubst :: ASTLore lore =>
               Pattern lore -> ExpDec lore -> Exp lore
            -> CSEState lore
            -> CSEState lore
addExpSubst pat edec e (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts', nsubsts) cse_arrays
  where esubsts' = M.insert (edec, e) pat esubsts

-- | The operations that permit CSE.
--
-- The @lore@ of the CSE monad is universally quantified in the method
-- type, independently of @op@.  Instances in this module that must run
-- CSE on code of a specific lore do so through 'subCSE'.
class CSEInOp op where
  -- | Perform CSE within any nested expressions.
  cseInOp :: op -> CSEM lore op

-- The unit op contains nothing to process.
instance CSEInOp () where
  cseInOp () = pure ()

-- | Run a CSE computation of a (possibly different) lore in a fresh
-- environment.  Only the array-CSE flag is inherited; the current
-- substitutions are not visible to the computation.
subCSE :: CSEM lore r -> CSEM otherlore r
subCSE m =
  asks $ \(CSEState _ cse_arrays) -> runReader m (newCSEState cse_arrays)

-- Recurse into segmented ops and into the wrapped inner op; every other
-- host op is returned unchanged.
instance (ASTLore lore, Aliased lore,
          CSEInOp (Op lore), CSEInOp op) => CSEInOp (Kernel.HostOp lore op) where
  cseInOp hostop =
    case hostop of
      Kernel.SegOp segop -> Kernel.SegOp <$> cseInOp segop
      Kernel.OtherOp other -> Kernel.OtherOp <$> cseInOp other
      _ -> pure hostop

-- CSE the lambdas and kernel bodies of a segmented op.  Sub-expressions,
-- names and levels are left untouched.  This runs in a fresh environment
-- via 'subCSE'.
instance (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
         CSEInOp (Kernel.SegOp lvl lore) where
  cseInOp segop =
    subCSE $
      Kernel.mapSegOpM
        (Kernel.SegOpMapper pure cseInLambda cseInKernelBody pure pure)
        segop

-- | Perform CSE on the statements of a kernel body; the kernel results
-- are kept exactly as they are.  The statements are processed as a
-- 'Body' with an empty result list, so the diet list handed to
-- 'cseInBody' has no effect and no result is treated as consumed.
cseInKernelBody :: (ASTLore lore, Aliased lore, CSEInOp (Op lore)) =>
                   Kernel.KernelBody lore -> CSEM lore (Kernel.KernelBody lore)
cseInKernelBody (Kernel.KernelBody bodydec bnds res) =
  rebuild <$> cseInBody (Observe <$ res) (Body bodydec bnds [])
  where rebuild (Body _ bnds' _) = Kernel.KernelBody bodydec bnds' res

-- | CSE for memory-annotated operations.
--
-- Allocations are returned unchanged; they are never CSE'd.  An inner
-- operation is handed to the inner op's own 'CSEInOp' instance, run via
-- 'subCSE'.
instance CSEInOp op => CSEInOp (Memory.MemOp op) where
  cseInOp o@Memory.Alloc{} = return o
  cseInOp (Memory.Inner k) = Memory.Inner <$> subCSE (cseInOp k)

-- | CSE for SOACs in the aliased representation.
--
-- Only the lambdas inside the SOAC are processed, each with 'cseInLambda'.
-- Subexpressions and variable names are mapped with 'return', so they are
-- left untouched.  The traversal runs as a separate 'CSEM' computation
-- via 'subCSE'.
instance (ASTLore lore,
          CanBeAliased (Op lore),
          CSEInOp (OpWithAliases (Op lore))) =>
         CSEInOp (SOAC.SOAC (Aliases lore)) where
  cseInOp = subCSE . SOAC.mapSOACM (SOAC.SOACMapper return cseInLambda return)