{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
-- | This module implements common-subexpression elimination.  This
-- module does not actually remove the duplicate, but only replaces
-- one with a reference to the other.  E.g:
--
-- @
--   let a = x + y
--   let b = x + y
-- @
--
-- becomes:
--
-- @
--   let a = x + y
--   let b = a
-- @
--
-- After which copy propagation in the simplifier will actually remove
-- the definition of @b@.
--
-- Our CSE is still rather stupid.  No normalisation is performed, so
-- the expressions @x+y@ and @y+x@ will be considered distinct.
-- Furthermore, no expression with its own binding will be considered
-- equal to any other, since the variable names will be distinct.
-- This affects SOACs in particular.
module Futhark.Optimise.CSE
       ( performCSE
       , performCSEOnFunDef
       , performCSEOnStms
       , CSEInOp
       )
       where

import Control.Monad.Reader
import qualified Data.Map.Strict as M

import Futhark.Analysis.Alias
import Futhark.Representation.AST
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
  (removeProgAliases, removeFunDefAliases, removeStmAliases,
   Aliases, consumedInStms)
import qualified Futhark.Representation.Kernels.Kernel as Kernel
import qualified Futhark.Representation.SOACS.SOAC as SOAC
import qualified Futhark.Representation.Mem as Memory
import Futhark.Transform.Substitute
import Futhark.Pass

-- | Perform CSE on every function in a program.
--
-- The program is first annotated with alias information
-- ('aliasAnalysis'), then CSE is run separately on the constants and
-- on each function, and finally the alias annotations are removed
-- again ('removeProgAliases').
--
-- The 'Bool' is stored in the initial 'CSEState'.  When it is 'False',
-- statements that bind arrays are never replaced (see 'cseInStm').
performCSE :: (Attributes lore, CanBeAliased (Op lore),
               CSEInOp (OpWithAliases (Op lore))) =>
              Bool -> Pass lore lore
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $
  fmap removeProgAliases .
  intraproceduralTransformationWithConsts onConsts onFun .
  aliasAnalysis
  where -- CSE over the constant statements, starting from an empty
        -- substitution state.
        onConsts stms =
          pure $ fst $
          runReader (cseInStms (consumedInStms stms) (stmsToList stms) (return ()))
          (newCSEState cse_arrays)
        -- The first argument is ignored; each function is processed
        -- on its own with a fresh state.
        onFun _ = pure . cseInFunDef cse_arrays

-- | Perform CSE on a single function.
--
-- Alias information is computed with 'analyseFun', CSE is applied to
-- the function body by 'cseInFunDef', and the alias annotations are
-- then stripped again with 'removeFunDefAliases'.  The 'Bool' is the
-- initial value of the CSE-on-arrays flag.
performCSEOnFunDef :: (Attributes lore, CanBeAliased (Op lore),
                       CSEInOp (OpWithAliases (Op lore))) =>
                      Bool -> FunDef lore -> FunDef lore
performCSEOnFunDef cse_arrays =
  removeFunDefAliases . cseInFunDef cse_arrays . analyseFun

-- | Perform CSE on some statements.
--
-- The statements are alias-analysed starting from an empty alias
-- table ('analyseStms' 'mempty'), CSE is run over them with a fresh
-- state, and the alias annotations are removed from every resulting
-- statement.  The 'Bool' is the initial value of the CSE-on-arrays
-- flag.
performCSEOnStms :: (Attributes lore, CanBeAliased (Op lore),
                     CSEInOp (OpWithAliases (Op lore))) =>
                    Bool -> Stms lore -> Stms lore
performCSEOnStms cse_arrays =
  fmap removeStmAliases . f . fst . analyseStms mempty
  where -- Run CSE over the aliased statements; nothing needs to be
        -- computed after the last statement, hence @return ()@.
        f stms =
          fst $ runReader (cseInStms (consumedInStms stms)
                           (stmsToList stms) (return ()))
          (newCSEState cse_arrays)

-- | Run CSE on the body of a function definition, starting from a
-- fresh 'CSEState' carrying the given CSE-on-arrays flag.  Only the
-- 'funDefBody' field is changed.
cseInFunDef :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
               Bool -> FunDef lore -> FunDef lore
cseInFunDef cse_arrays fundec =
  fundec { funDefBody =
              runReader (cseInBody ds $ funDefBody fundec) $ newCSEState cse_arrays
         }
  where -- One 'Diet' per declared return type; 'cseInBody' uses these
        -- to decide which results count as consumed.
        ds = map (diet . declExtTypeOf) $ funDefRetType fundec

-- | The monad in which CSE runs: a 'Reader' whose environment is the
-- current 'CSEState'.
type CSEM lore = Reader (CSEState lore)

-- | Run CSE on a body.
--
-- The list of 'Diet's is matched positionally against the body
-- result: the free names of every result whose diet is 'Consume' are
-- added to the names consumed by the body's own statements, and that
-- combined set is passed to 'cseInStms' as the consumed names.
--
-- After the statements have been processed, the name substitutions
-- accumulated along the way are applied to the result.
cseInBody :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
             [Diet] -> Body lore -> CSEM lore (Body lore)
cseInBody ds (Body bodyattr bnds res) = do
  (bnds', res') <-
    cseInStms (res_cons <> consumedInStms bnds) (stmsToList bnds) $ do
    -- This runs in the environment extended by all of the statements.
    CSEState (_, nsubsts) _ <- ask
    return $ substituteNames nsubsts res
  return $ Body bodyattr bnds' res'
  where res_cons = mconcat $ zipWith consumeResult ds res
        consumeResult Consume se = freeIn se
        consumeResult _ _ = mempty

-- | Run CSE on the body of a lambda.  Every result of the lambda is
-- given the diet 'Observe', so no result contributes to the consumed
-- names.  Only the 'lambdaBody' field is changed.
cseInLambda :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
               Lambda lore -> CSEM lore (Lambda lore)
cseInLambda lam = do
  body' <- cseInBody (map (const Observe) $ lambdaReturnType lam) $ lambdaBody lam
  return lam { lambdaBody = body' }

-- | Run CSE over a list of statements, then run the given action in
-- the environment produced by those statements.
--
-- The 'Names' argument is the set of consumed names.  It is handed to
-- 'cseInStm' for every statement, and it also decides the diets used
-- when recursing into the bodies nested inside a statement.
--
-- Returns the rewritten statements together with the result of the
-- action.
cseInStms :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
             Names -> [Stm lore]
          -> CSEM lore a
          -> CSEM lore (Stms lore, a)
cseInStms _ [] m = do a <- m
                      return (mempty, a)
cseInStms consumed (bnd:bnds) m =
  cseInStm consumed bnd $ \bnd' -> do
    -- Everything in this continuation runs in the environment that
    -- 'cseInStm' extended for the current statement.
    (bnds', a) <- cseInStms consumed bnds m
    bnd'' <- mapM nestedCSE bnd'
    return (stmsFromList bnd'' <> bnds', a)
  where -- Recurse into the bodies and operations of one statement's
        -- expression.
        nestedCSE bnd' = do
          let ds = map patElemDiet $ patternValueElements $ stmPattern bnd'
          e <- mapExpM (cse ds) $ stmExp bnd'
          return bnd' { stmExp = e }

        cse ds = identityMapper { mapOnBody = const $ cseInBody ds
                                , mapOnOp = cseInOp
                                }

        -- A pattern element is consumed exactly when its name is in
        -- the consumed set.
        patElemDiet pe | patElemName pe `nameIn` consumed = Consume
                       | otherwise                        = Observe

-- | Run CSE on a single statement and pass the resulting statements
-- to the continuation, which is run in a suitably extended
-- environment.
--
-- The current name substitutions are first applied to both the
-- pattern and the expression.  Then one of three things happens:
--
-- * If any value element of the pattern is memory-typed, is
--   array-typed while the CSE-on-arrays flag is off, or has its name
--   in the consumed set, the statement is passed on as-is and the
--   environment is left untouched.
--
-- * If the expression (paired with its attribute) is already in the
--   expression table, the statement is replaced by one copy statement
--   per pattern element, each binding the new name to the variable
--   bound by the earlier statement.  The new names are also recorded
--   as substitutable by the earlier ones.
--
-- * Otherwise the statement is passed on and its expression is added
--   to the expression table.
cseInStm :: Attributes lore =>
            Names -> Stm lore
         -> ([Stm lore] -> CSEM lore a)
         -> CSEM lore a
cseInStm consumed (Let pat (StmAux cs eattr) e) m = do
  CSEState (esubsts, nsubsts) cse_arrays <- ask
  let e' = substituteNames nsubsts e
      pat' = substituteNames nsubsts pat
  if any (bad cse_arrays) $ patternValueElements pat then
    m [Let pat' (StmAux cs eattr) e']
    else
    case M.lookup (eattr, e') esubsts of
      Just subpat ->
        local (addNameSubst pat' subpat) $ do
          let lets =
                [ Let (Pattern [] [patElem']) (StmAux cs eattr) $
                    BasicOp $ SubExp $ Var $ patElemName patElem
                | (name, patElem) <- zip (patternNames pat') $ patternElements subpat
                  -- Reuse the earlier pattern element, but under the
                  -- name that this statement binds.
                , let patElem' = patElem { patElemName = name }
                ]
          m lets
      Nothing ->
        local (addExpSubst pat' eattr e') $
        m [Let pat' (StmAux cs eattr) e']

  where -- Pattern elements for which the statement must not be CSE'd.
        bad cse_arrays pe
          | Mem{} <- patElemType pe = True
          | Array{} <- patElemType pe, not cse_arrays = True
          | patElemName pe `nameIn` consumed = True
          | otherwise = False

-- | Maps an expression, paired with its expression attribute, to the
-- pattern of the statement that was recorded for it.
type ExpressionSubstitutions lore = M.Map
                                    (ExpAttr lore, Exp lore)
                                    (Pattern lore)
-- | Maps a variable name to the name that should replace it.
type NameSubstitutions = M.Map VName VName

-- | The environment carried by 'CSEM'.
data CSEState lore = CSEState
                     { _cseSubstitutions :: (ExpressionSubstitutions lore, NameSubstitutions)
                       -- ^ The expressions seen so far, and the name
                       -- substitutions that have resulted from
                       -- replacing duplicates.
                     , _cseArrays :: Bool
                       -- ^ When 'False', statements binding arrays
                       -- are never subject to CSE.
                     }

-- | An initial state with no expression or name substitutions.  The
-- argument becomes the CSE-on-arrays flag.
newCSEState :: Bool -> CSEState lore
newCSEState = CSEState (M.empty, M.empty)

-- | Pair up the names of two patterns positionally, mapping each name
-- of the first pattern to the corresponding name of the second.  If
-- the patterns have different numbers of names, the surplus names of
-- the longer one are ignored (this is the behaviour of 'zip').
mkSubsts :: PatternT attr -> PatternT attr -> M.Map VName VName
mkSubsts pat vs = M.fromList $ zip (patternNames pat) (patternNames vs)

-- | Record that the names of the first pattern are to be replaced by
-- the corresponding names of the second.  'M.union' is left-biased,
-- so the new substitutions take precedence over existing ones for
-- the same name.
addNameSubst :: PatternT attr -> PatternT attr -> CSEState lore -> CSEState lore
addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts, mkSubsts pat subpat `M.union` nsubsts) cse_arrays

-- | Record that the given expression (with the given attribute) is
-- bound by the given pattern, so that later occurrences of the same
-- expression can be replaced.  An existing entry for the same key is
-- overwritten.
addExpSubst :: Attributes lore =>
               Pattern lore -> ExpAttr lore -> Exp lore
            -> CSEState lore
            -> CSEState lore
addExpSubst pat eattr e (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (M.insert (eattr, e) pat esubsts, nsubsts) cse_arrays

-- | The operations that permit CSE.
class CSEInOp op where
  -- | Perform CSE within any nested expressions.  The @lore@ of the
  -- surrounding 'CSEM' is unconstrained, so an implementation cannot
  -- rely on it matching the lore of the operation itself.
  cseInOp :: op -> CSEM lore op

-- | The unit operation contains nothing to perform CSE on.
instance CSEInOp () where
  cseInOp () = return ()

-- | Run a CSE computation for one lore inside a CSE computation for
-- a possibly different lore.  The inner computation starts from a
-- fresh state: only the CSE-on-arrays flag is carried over, and none
-- of the outer substitutions are visible to it.
subCSE :: CSEM lore r -> CSEM otherlore r
subCSE m = do
  CSEState _ cse_arrays <- ask
  return $ runReader m $ newCSEState cse_arrays

-- | CSE recurses into segmented operations and into the wrapped
-- inner operation; any other host operation is returned unchanged.
instance (Attributes lore, Aliased lore,
          CSEInOp (Op lore), CSEInOp op) => CSEInOp (Kernel.HostOp lore op) where
  cseInOp (Kernel.SegOp op) = Kernel.SegOp <$> cseInOp op
  cseInOp (Kernel.OtherOp op) = Kernel.OtherOp <$> cseInOp op
  cseInOp x = return x

-- | CSE is performed in the lambdas and kernel body of a segmented
-- operation; the other components handled by the mapper are returned
-- unchanged.  The traversal runs under 'subCSE', so it starts from a
-- fresh substitution state.
instance (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
         CSEInOp (Kernel.SegOp lvl lore) where
  cseInOp = subCSE .
            Kernel.mapSegOpM
            (Kernel.SegOpMapper return cseInLambda cseInKernelBody return return)

-- | Run CSE on the statements of a kernel body.
--
-- The statements are wrapped in an ordinary 'Body' with an empty
-- result so that 'cseInBody' can be reused; only the rewritten
-- statements are taken from its output.  The kernel results
-- themselves are returned unchanged.
cseInKernelBody :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
                   Kernel.KernelBody lore -> CSEM lore (Kernel.KernelBody lore)
cseInKernelBody (Kernel.KernelBody bodyattr bnds res) = do
  Body _ bnds' _ <- cseInBody (map (const Observe) res) $ Body bodyattr bnds []
  return $ Kernel.KernelBody bodyattr bnds' res

-- | CSE inside memory operations: allocations are left untouched,
-- while a wrapped inner operation is handled by its own 'CSEInOp'
-- instance.
instance CSEInOp op => CSEInOp (Memory.MemOp op) where
  -- An allocation is returned as-is.
  cseInOp o@Memory.Alloc{} = return o
  -- Delegate to the inner operation's 'cseInOp', run through 'subCSE'
  -- so that it can be used from the enclosing lore's 'CSEM', then
  -- re-wrap the result in 'Memory.Inner'.
  cseInOp (Memory.Inner k) = Memory.Inner <$> subCSE (cseInOp k)

-- | CSE inside SOACs over an alias-annotated lore.  Every lambda of
-- the SOAC is processed with 'cseInLambda'; subexpressions and
-- variable names are passed through unchanged.
instance (Attributes lore,
          CanBeAliased (Op lore),
          CSEInOp (OpWithAliases (Op lore))) =>
         CSEInOp (SOAC.SOAC (Aliases lore)) where
  -- The traversal runs in @CSEM (Aliases lore)@; 'subCSE' converts it
  -- into a computation for whichever lore the caller is in.
  cseInOp = subCSE . SOAC.mapSOACM (SOAC.SOACMapper return cseInLambda return)