{-# LANGUAGE UndecidableInstances #-}
module Futhark.Optimise.CSE
( performCSE,
performCSEOnFunDef,
performCSEOnStms,
CSEInOp,
)
where
import Control.Monad.Reader
import Data.Map.Strict qualified as M
import Futhark.Analysis.Alias
import Futhark.IR
import Futhark.IR.Aliases
( Aliases,
consumedInStms,
mkStmsAliases,
removeFunDefAliases,
removeProgAliases,
removeStmAliases,
)
import Futhark.IR.GPU qualified as GPU
import Futhark.IR.MC qualified as MC
import Futhark.IR.Mem qualified as Memory
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS.SOAC qualified as SOAC
import Futhark.Pass
import Futhark.Transform.Substitute
-- | Construct the pass named @"CSE"@, which performs
-- common-subexpression elimination on the constants and on every
-- function of a program.
--
-- The program is first annotated with alias information
-- ('aliasAnalysis'), then transformed, and finally the alias
-- annotations are stripped again ('removeProgAliases').
--
-- The boolean is stored in the initial 'CSEState'; when it is 'False',
-- statements that bind arrays are not subject to CSE (see 'cseInStm').
performCSE ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  Pass rep rep
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $ \prog ->
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts
        (onConsts (freeIn (progFuns prog)))
        onFun
      . aliasAnalysis
      $ prog
  where
    -- The constants are processed with a "consumed" set that also
    -- contains every name free in the functions, so that statements
    -- binding such names are passed through by 'cseInStm' instead of
    -- being replaced.
    onConsts free_in_funs stms =
      pure $
        fst $
          runReader
            ( cseInStms
                (free_in_funs <> consumedInStms stms)
                (stmsToList stms)
                (pure ())
            )
            (newCSEState cse_arrays)

    -- The (already transformed) constants are not needed when
    -- processing a function.
    onFun _ = pure . cseInFunDef cse_arrays
-- | Perform CSE on a single function definition.
--
-- The function is annotated with aliases ('analyseFun'), transformed
-- by 'cseInFunDef', and the alias annotations are removed again.  The
-- boolean has the same meaning as for 'performCSE'.
performCSEOnFunDef ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  FunDef rep ->
  FunDef rep
performCSEOnFunDef cse_arrays =
  removeFunDefAliases . cseInFunDef cse_arrays . analyseFun
-- | Perform CSE on a sequence of statements.
--
-- The statements are alias-analysed starting from an empty alias
-- table, transformed, and then stripped of alias annotations.  The
-- boolean has the same meaning as for 'performCSE'.
performCSEOnStms ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  Stms rep ->
  Stms rep
performCSEOnStms cse_arrays =
  fmap removeStmAliases . f . fst . analyseStms mempty
  where
    -- Run CSE over the alias-annotated statements; names consumed by
    -- the statements themselves are exempt from CSE.
    f stms =
      fst $
        runReader
          ( cseInStms
              (consumedInStms stms)
              (stmsToList stms)
              (pure ())
          )
          (newCSEState cse_arrays)
-- | Perform CSE on the body of an alias-annotated function, starting
-- from a fresh 'CSEState'.  Everything but the body is left unchanged.
cseInFunDef ::
  (Aliased rep, CSEInOp (Op rep)) =>
  Bool ->
  FunDef rep ->
  FunDef rep
cseInFunDef cse_arrays fundec =
  fundec
    { funDefBody =
        runReader (cseInBody ds $ funDefBody fundec) $
          newCSEState cse_arrays
    }
  where
    -- One 'Diet' per declared return value.
    ds = map (retDiet . fst) $ funDefRetType fundec

    -- Results of primitive type are merely observed; every other
    -- result is treated as consumed.
    retDiet t
      | primType $ declExtTypeOf t = Observe
      | otherwise = Consume
-- | The monad in which CSE is carried out: a 'Reader' whose
-- environment is the current 'CSEState'.
type CSEM rep = Reader (CSEState rep)
-- | Perform CSE on a body.
--
-- The list of 'Diet's states, per result, whether the result is
-- consumed.  The aliases of consumed results, together with everything
-- consumed by the statements of the body, form the set of names that
-- 'cseInStms' must leave alone.  The result is rewritten using the
-- name substitutions in effect after the last statement.
cseInBody ::
  (Aliased rep, CSEInOp (Op rep)) =>
  [Diet] ->
  Body rep ->
  CSEM rep (Body rep)
cseInBody ds (Body bodydec stms res) = do
  (stms', res') <-
    cseInStms (res_cons <> stms_cons) (stmsToList stms) $ do
      -- This action runs in the environment built up by the
      -- statements, so 'nsubsts' covers all of them.
      CSEState (_, nsubsts) _ <- ask
      pure $ substituteNames nsubsts res
  pure $ Body bodydec stms' res'
  where
    (res_als, stms_cons) = mkStmsAliases stms res

    -- Names aliased by results that are marked as consumed.
    res_cons = mconcat $ zipWith consumeResult ds res_als

    consumeResult Consume als = als
    consumeResult _ _ = mempty
-- | Perform CSE on the body of a lambda.  Every result of the lambda
-- is treated as observed (not consumed).
cseInLambda ::
  (Aliased rep, CSEInOp (Op rep)) =>
  Lambda rep ->
  CSEM rep (Lambda rep)
cseInLambda lam = do
  body' <-
    cseInBody (map (const Observe) $ lambdaReturnType lam) $
      lambdaBody lam
  pure lam {lambdaBody = body'}
-- | Perform CSE on a list of statements, and then run the given action
-- in the environment that results from the last statement.
--
-- The 'Names' argument is the set of names that are consumed; a
-- statement binding any of them is not subject to CSE (see
-- 'cseInStm').  Returns the rewritten statements together with the
-- result of the action.
cseInStms ::
  forall rep a.
  (Aliased rep, CSEInOp (Op rep)) =>
  Names ->
  [Stm rep] ->
  CSEM rep a ->
  CSEM rep (Stms rep, a)
cseInStms _ [] m = do
  a <- m
  pure (mempty, a)
cseInStms consumed (stm : stms) m =
  cseInStm consumed stm $ \stm' -> do
    -- First handle the remaining statements (in the environment
    -- extended by 'cseInStm'), then descend into whatever statements
    -- replaced the current one.
    (stms', a) <- cseInStms consumed stms m
    stm'' <- mapM nestedCSE stm'
    pure (stmsFromList stm'' <> stms', a)
  where
    -- Apply CSE inside the bodies and operation nested in a statement.
    nestedCSE stm' = do
      let ds =
            case stmExp stm' of
              -- For loops, consumption is determined by the declared
              -- types of the merge parameters.
              DoLoop merge _ _ -> map (diet . declTypeOf . fst) merge
              -- Otherwise, by whether each bound name is consumed.
              _ -> map patElemDiet $ patElems $ stmPat stm'
      e <- mapExpM (cse ds) $ stmExp stm'
      pure stm' {stmExp = e}

    cse ds =
      (identityMapper @rep)
        { mapOnBody = const $ cseInBody ds,
          mapOnOp = cseInOp
        }

    patElemDiet pe
      | patElemName pe `nameIn` consumed = Consume
      | otherwise = Observe
-- | Normalise an expression before it is used as a lookup key, such
-- that expressions differing only in irrelevant details compare
-- equal.  For function applications the two source-location
-- components are replaced by 'mempty'; all other expressions are
-- returned unchanged.
normExp :: Exp rep -> Exp rep
normExp (Apply fname args ret (safety, _, _)) =
  Apply fname args ret (safety, mempty, mempty)
normExp e = e
-- | Perform CSE on a single statement, passing the statement(s) that
-- replace it to the continuation.  The continuation runs in an
-- environment that reflects the outcome:
--
-- * If the statement is exempt from CSE, it is passed on (with the
--   current name substitutions applied) and the environment is
--   unchanged.
--
-- * If an identical expression (with identical expression decoration)
--   has been seen before, the statement is replaced by one copy
--   statement per pattern element, each referring to the corresponding
--   name of the earlier pattern.
--
-- * Otherwise the expression is recorded for later statements and the
--   statement is passed on.
--
-- The 'Names' argument is the set of consumed names.
cseInStm ::
  ASTRep rep =>
  Names ->
  Stm rep ->
  ([Stm rep] -> CSEM rep a) ->
  CSEM rep a
cseInStm consumed (Let pat (StmAux cs attrs edec) e) m = do
  CSEState (esubsts, nsubsts) cse_arrays <- ask
  let e' = normExp $ substituteNames nsubsts e
      pat' = substituteNames nsubsts pat
  if not (alreadyAliases e) && any (bad cse_arrays) (patElems pat)
    then m [Let pat' (StmAux cs attrs edec) e']
    else case M.lookup (edec, e') esubsts of
      Just (subcs, subpat) -> do
        -- A plain name substitution is only installed when every
        -- certificate of this statement is also a certificate of the
        -- earlier statement.
        let subsumes = all (`elem` unCerts subcs) (unCerts cs)
        local (if subsumes then addNameSubst pat' subpat else id) $ do
          let lets =
                [ Let (Pat [patElem']) (StmAux cs attrs edec) $
                    BasicOp (SubExp $ Var $ patElemName patElem)
                  | (name, patElem) <- zip (patNames pat') $ patElems subpat,
                    let patElem' = patElem {patElemName = name}
                ]
          m lets
      _ ->
        local (addExpSubst pat' edec cs e') $
          m [Let pat' (StmAux cs attrs edec) e']
  where
    -- Expressions for which the 'bad' check below is skipped.
    alreadyAliases (BasicOp Index {}) = True
    alreadyAliases (BasicOp Reshape {}) = True
    alreadyAliases _ = False

    -- A pattern element that prevents CSE of its statement: memory
    -- blocks, arrays when array CSE is disabled, and consumed names.
    bad cse_arrays pe
      | Mem {} <- patElemType pe = True
      | Array {} <- patElemType pe, not cse_arrays = True
      | patElemName pe `nameIn` consumed = True
      | otherwise = False
-- | The expressions seen so far.  An expression and its decoration
-- map to the certificates and the pattern of the statement that
-- computed it.
type ExpressionSubstitutions rep =
  M.Map
    (ExpDec rep, Exp rep)
    (Certs, Pat (LetDec rep))
-- | Maps a name to the name that is to be used in its place.
type NameSubstitutions = M.Map VName VName
-- | The environment carried by the 'CSEM' monad.
data CSEState rep = CSEState
  { -- | The expressions recorded so far, paired with the name
    -- substitutions currently in effect.
    _cseSubstitutions :: (ExpressionSubstitutions rep, NameSubstitutions),
    -- | When 'False', statements binding arrays are exempt from CSE
    -- (see 'cseInStm').
    _cseArrays :: Bool
  }
-- | An initial state with no recorded expressions and no name
-- substitutions.  The argument becomes the array-CSE flag.
newCSEState :: Bool -> CSEState rep
newCSEState = CSEState (M.empty, M.empty)
-- | Pair up the names of two patterns positionally, mapping each name
-- of the first pattern to the corresponding name of the second.  If
-- the patterns differ in length, the excess names are ignored (this
-- follows from 'zip').
mkSubsts :: Pat dec -> Pat dec -> M.Map VName VName
mkSubsts pat vs = M.fromList $ zip (patNames pat) (patNames vs)
-- | Extend the state such that the names of the first pattern are
-- replaced by the corresponding names of the second.  Since
-- 'M.union' is left-biased, the new substitutions take precedence
-- over existing ones for the same name.
addNameSubst :: Pat dec -> Pat dec -> CSEState rep -> CSEState rep
addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts, mkSubsts pat subpat `M.union` nsubsts) cse_arrays
-- | Record that the given expression (with the given decoration) was
-- computed under the given certificates and bound to the given
-- pattern.  An existing entry for the same key is overwritten.
addExpSubst ::
  ASTRep rep =>
  Pat (LetDec rep) ->
  ExpDec rep ->
  Certs ->
  Exp rep ->
  CSEState rep ->
  CSEState rep
addExpSubst pat edec cs e (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (M.insert (edec, e) (cs, pat) esubsts, nsubsts) cse_arrays
-- | Operations inside which CSE can be performed.
class CSEInOp op where
  -- | Perform CSE on whatever is nested inside the operation.
  cseInOp :: op -> CSEM rep op
-- | There is nothing inside a 'NoOp', so it is returned as it is.
instance CSEInOp (NoOp rep) where
  cseInOp NoOp = pure NoOp
-- | Run a CSE action for one representation from within a CSE action
-- for a possibly different representation.  The inner action starts
-- from a fresh state; only the array-CSE flag is carried over, so no
-- substitutions flow in or out.
subCSE :: CSEM rep r -> CSEM otherrep r
subCSE m = do
  CSEState _ cse_arrays <- ask
  pure $ runReader m $ newCSEState cse_arrays
-- | CSE inside GPU host operations.  Segmented operations and the
-- wrapped inner operation are handled by their own instances.  A
-- 'GPU.GPUBody' is processed via 'subCSE', i.e. starting from a fresh
-- state, with all of its results treated as observed.  Any other
-- constructor is returned unchanged.
instance
  ( Aliased rep,
    CSEInOp (Op rep),
    CSEInOp (op rep)
  ) =>
  CSEInOp (GPU.HostOp op rep)
  where
  cseInOp (GPU.SegOp op) = GPU.SegOp <$> cseInOp op
  cseInOp (GPU.OtherOp op) = GPU.OtherOp <$> cseInOp op
  cseInOp (GPU.GPUBody types body) =
    subCSE $
      GPU.GPUBody types
        <$> cseInBody (map (const Observe) types) body
  cseInOp x = pure x
-- | CSE inside multicore operations.  A 'MC.ParOp' carries an
-- optional and a mandatory segmented operation; both are processed
-- with 'cseInOp' (the optional one first).  A 'MC.OtherOp' delegates
-- to the 'CSEInOp' instance of the wrapped operation.
instance
  ( Aliased rep,
    CSEInOp (Op rep),
    CSEInOp (op rep)
  ) =>
  CSEInOp (MC.MCOp op rep)
  where
  cseInOp (MC.ParOp maybe_segop segop) = do
    maybe_segop' <- traverse cseInOp maybe_segop
    segop' <- cseInOp segop
    pure $ MC.ParOp maybe_segop' segop'
  cseInOp (MC.OtherOp inner) =
    fmap MC.OtherOp (cseInOp inner)
-- | CSE inside a segmented operation.  The operation is traversed
-- with 'GPU.mapSegOpM': lambdas are processed by 'cseInLambda' and
-- kernel bodies by 'cseInKernelBody', while subexpressions, names
-- and the level are left untouched.  The whole traversal runs under
-- 'subCSE'.
instance
  (Aliased rep, CSEInOp (Op rep)) =>
  CSEInOp (GPU.SegOp lvl rep)
  where
  cseInOp segop =
    subCSE $
      GPU.mapSegOpM
        ( GPU.SegOpMapper
            pure -- SubExp: unchanged
            cseInLambda -- Lambda
            cseInKernelBody -- KernelBody
            pure -- VName: unchanged
            pure -- level: unchanged
        )
        segop
-- | Perform CSE on the statements of a kernel body.
--
-- 'cseInBody' operates on an ordinary 'Body', so the kernel's
-- statements are wrapped in a temporary 'Body' with an empty result
-- list.  Only the rewritten statements are taken from what comes
-- back; the body decoration and the kernel results are reattached
-- exactly as they were.  One 'Observe' is passed per kernel result.
cseInKernelBody ::
  (Aliased rep, CSEInOp (Op rep)) =>
  GPU.KernelBody rep ->
  CSEM rep (GPU.KernelBody rep)
cseInKernelBody (GPU.KernelBody dec stms results) =
  rebuild <$> cseInBody diets (Body dec stms [])
  where
    diets = map (const Observe) results
    -- Keep the original decoration and results; take only the new
    -- statements from the processed body.
    rebuild (Body _ stms' _) = GPU.KernelBody dec stms' results
-- | CSE inside memory-annotated operations.  An allocation is
-- returned unchanged; an 'Memory.Inner' operation is processed by
-- the wrapped operation's 'cseInOp', run under 'subCSE'.
instance CSEInOp (op rep) => CSEInOp (Memory.MemOp op rep) where
  cseInOp memop =
    case memop of
      Memory.Alloc {} -> pure memop
      Memory.Inner inner -> Memory.Inner <$> subCSE (cseInOp inner)
-- | CSE inside a SOAC.  The SOAC is traversed with 'SOAC.mapSOACM':
-- lambdas are processed by 'cseInLambda', while subexpressions and
-- names are left untouched.  The traversal runs under 'subCSE'.
instance
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  CSEInOp (SOAC.SOAC (Aliases rep))
  where
  cseInOp soac =
    subCSE $
      SOAC.mapSOACM
        ( SOAC.SOACMapper
            pure -- SubExp: unchanged
            cseInLambda -- Lambda
            pure -- VName: unchanged
        )
        soac