{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Sink (sink) where
import Control.Monad.State
import Data.List (foldl')
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.Pass
-- | The representation this pass works on: 'Kernels' annotated with
-- alias information ('sink' runs 'Alias.aliasAnalysis' first and
-- removes the aliases again afterwards).
type SinkLore = Aliases Kernels

-- | Symbol table over 'SinkLore'.
type SymbolTable = ST.SymbolTable SinkLore

-- | Statements that are candidates for sinking but have not been placed
-- yet, keyed by the (single) name each statement binds.
type Sinking = M.Map VName (Stm SinkLore)

-- | Names bound by statements that have been sunk into some branch.
type Sunk = S.Set VName
-- | Approximate how many times each free variable of a statement is used.
--
-- The count is deliberately inexact: callers only distinguish "used at
-- most once" from "used more than once".
--
-- * For an 'If', the condition and each branch contribute one use per
--   variable free in them, so a variable used in both branches counts 2.
-- * Every variable free in an 'Op' or a 'DoLoop' counts 2. Presumably
--   this is because such code may execute many times, so a load sunk
--   into it could be repeated.
-- * Any other statement contributes one use per free variable.
--
-- NOTE(review): the 'If' case ignores the branch return annotation (the
-- @_@ field) and the statement's pattern. Confirm that a variable
-- occurring only there and in one branch cannot be sunk out of scope.
multiplicity :: Stm SinkLore -> M.Map VName Int
multiplicity stm =
  case stmExp stm of
    If cond tbranch fbranch _ ->
      free cond 1 `comb` free tbranch 1 `comb` free fbranch 1
    Op {} -> free stm 2
    DoLoop {} -> free stm 2
    _ -> free stm 1
  where
    -- Assign the count @k@ to every variable free in @x@.
    free :: FreeIn a => a -> Int -> M.Map VName Int
    free x k = M.fromList [(v, k) | v <- namesToList (freeIn x)]

    -- Add up the counts of two usage maps.
    comb :: M.Map VName Int -> M.Map VName Int -> M.Map VName Int
    comb = M.unionWith (+)
-- | Optimise one branch of an 'If', placing pending statements at its start.
--
-- A pending statement from @sinking@ is placed at the start of this
-- branch when both of these hold:
--
-- * the name it binds is free in the branch (its statements or result);
-- * all of its own free variables are 'ST.available' in @vtable@, the
--   symbol table at the 'If'.
--
-- The remaining pending statements are passed on to the branch's own
-- statements. The result pairs the new body with the names of every
-- statement sunk into it, whether placed here or deeper.
optimiseBranch ::
  SymbolTable ->
  Sinking ->
  Body SinkLore ->
  (Body SinkLore, Sunk)
optimiseBranch vtable sinking (Body dec stms res) =
  let (stms', stms_sunk) = optimiseStms vtable sinking' stms $ freeIn res
   in ( Body dec (sunk_stms <> stms') res,
        sunk <> stms_sunk
      )
  where
    -- Everything the branch refers to.
    free_in_stms :: Names
    free_in_stms = freeIn stms <> freeIn res

    -- Split pending statements into those placed here and the rest.
    (sinking_here, sinking') = M.partitionWithKey sunkHere sinking

    -- Statements to prepend to the branch, in the key order of the map.
    sunk_stms :: Stms SinkLore
    sunk_stms = stmsFromList $ M.elems sinking_here

    -- A pending statement belongs here if its name is used in the branch
    -- and everything it reads is available at the 'If'.
    sunkHere :: VName -> Stm SinkLore -> Bool
    sunkHere v stm =
      v `nameIn` free_in_stms
        && all (`ST.available` vtable) (namesToList (freeIn stm))

    -- Names bound by the statements placed in this branch.
    sunk :: Sunk
    sunk = S.fromList $ concatMap (patternNames . stmPattern) sunk_stms
-- | Sink eligible statements of a sequence into the 'If' branches that use them.
--
-- @free_in_res@ are the names used after the sequence (for example the
-- enclosing body's result); each counts as one use.
--
-- A statement is a sinking candidate when all of these hold:
--
-- * it is an 'Index' 'BasicOp';
-- * it binds exactly one name, and that name has a primitive type;
-- * the name is used at most once in the sequence, as counted by
--   'multiplicity' plus @free_in_res@.
--
-- Candidates are added to the pending map seen by later statements. If a
-- candidate's name comes back in the sunk set, the original statement is
-- dropped from its position. The pending map @init_sinking@ from
-- enclosing code is also passed along.
optimiseStms ::
  SymbolTable ->
  Sinking ->
  Stms SinkLore ->
  Names ->
  (Stms SinkLore, Sunk)
optimiseStms init_vtable init_sinking all_stms free_in_res =
  let (all_stms', sunk) =
        optimiseStms' init_vtable init_sinking $ stmsToList all_stms
   in (stmsFromList all_stms', sunk)
  where
    -- Approximate use counts for every name used in this sequence or
    -- after it.
    multiplicities :: M.Map VName Int
    multiplicities =
      foldl'
        (M.unionWith (+))
        (M.fromList [(v, 1) | v <- namesToList free_in_res])
        (map multiplicity $ stmsToList all_stms)

    optimiseStms' ::
      SymbolTable ->
      Sinking ->
      [Stm SinkLore] ->
      ([Stm SinkLore], Sunk)
    optimiseStms' _ _ [] = ([], mempty)
    optimiseStms' vtable sinking (stm : stms)
      -- Sinking candidate: offer it to the rest of the sequence, and drop
      -- it from here if it was placed somewhere.
      | BasicOp Index {} <- stmExp stm,
        [pe] <- patternElements (stmPattern stm),
        primType $ patElemType pe,
        maybe True (== 1) $ M.lookup (patElemName pe) multiplicities =
        let (stms', sunk) =
              optimiseStms' vtable' (M.insert (patElemName pe) stm sinking) stms
         in if patElemName pe `S.member` sunk
              then (stms', sunk)
              else (stm : stms', sunk)
      -- Branches are where pending statements get placed. They see the
      -- symbol table as it was before the 'If'.
      | If cond tbranch fbranch ret <- stmExp stm =
        let (tbranch', tsunk) = optimiseBranch vtable sinking tbranch
            (fbranch', fsunk) = optimiseBranch vtable sinking fbranch
            (stms', sunk) = optimiseStms' vtable' sinking stms
         in ( stm {stmExp = If cond tbranch' fbranch' ret} : stms',
              tsunk <> fsunk <> sunk
            )
      -- Recurse into the bodies of a 'SegOp', with its index space in scope.
      | Op (SegOp op) <- stmExp stm =
        let scope = scopeOfSegSpace $ segSpace op
            (stms', stms_sunk) = optimiseStms' vtable' sinking stms
            (op', op_sunk) = runState (mapSegOpM (opMapper scope) op) mempty
         in ( stm {stmExp = Op (SegOp op')} : stms',
              stms_sunk <> op_sunk
            )
      -- Any other expression: recurse into whatever bodies it contains.
      | otherwise =
        let (stms', stms_sunk) = optimiseStms' vtable' sinking stms
            (e', stm_sunk) = runState (mapExpM mapper (stmExp stm)) mempty
         in ( stm {stmExp = e'} : stms',
              stm_sunk <> stms_sunk
            )
      where
        -- Symbol table for the statements following @stm@.
        vtable' :: SymbolTable
        vtable' = ST.insertStm stm vtable

        -- Optimise nested bodies of a generic expression, accumulating
        -- the names sunk inside them in the state.
        mapper :: Mapper SinkLore SinkLore (State Sunk)
        mapper =
          identityMapper
            { mapOnBody = \scope body -> do
                let (body', sunk) =
                      optimiseBody (ST.fromScope scope <> vtable) sinking body
                modify (<> sunk)
                pure body'
            }

        -- Optimise the lambda and kernel bodies of a 'SegOp' whose local
        -- scope is @scope@, accumulating the names sunk inside them.
        opMapper :: Scope SinkLore -> SegOpMapper SegLevel SinkLore SinkLore (State Sunk)
        opMapper scope =
          identitySegOpMapper
            { mapOnSegOpLambda = \lam -> do
                let (body, sunk) =
                      optimiseBody op_vtable sinking $ lambdaBody lam
                modify (<> sunk)
                pure lam {lambdaBody = body},
              mapOnSegOpBody = \body -> do
                let (body', sunk) =
                      optimiseKernelBody op_vtable sinking body
                modify (<> sunk)
                pure body'
            }
          where
            op_vtable :: SymbolTable
            op_vtable = ST.fromScope scope <> vtable
-- | Optimise the statements of a body.
--
-- The names free in the body's result count as uses. This function does
-- not place pending statements at the start of the body itself; only
-- 'optimiseBranch' does that. Returns the new body and the names of the
-- statements sunk within it.
optimiseBody ::
  SymbolTable ->
  Sinking ->
  Body SinkLore ->
  (Body SinkLore, Sunk)
optimiseBody vtable sinking (Body dec stms res) =
  let (stms', sunk) = optimiseStms vtable sinking stms $ freeIn res
   in (Body dec stms' res, sunk)
-- | Optimise the statements of a 'KernelBody'.
--
-- The names free in the kernel results count as uses. Returns the new
-- kernel body and the names of the statements sunk within it.
optimiseKernelBody ::
  SymbolTable ->
  Sinking ->
  KernelBody SinkLore ->
  (KernelBody SinkLore, Sunk)
optimiseKernelBody vtable sinking (KernelBody dec stms res) =
  let (stms', sunk) = optimiseStms vtable sinking stms $ freeIn res
   in (KernelBody dec stms' res, sunk)
-- | The sinking pass: move memory loads closer to their uses.
--
-- The program is first annotated with aliases ('Alias.aliasAnalysis').
-- Then every function body and the top-level constants are optimised.
-- Finally the alias annotations are removed again.
sink :: Pass Kernels Kernels
sink =
  Pass "sink" "move memory loads closer to their uses" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts onConsts onFun
      . Alias.aliasAnalysis
  where
    -- Function bodies start with only the parameters in the symbol table
    -- and no pending statements. The constants argument is unused.
    onFun :: Stms SinkLore -> FunDef SinkLore -> PassM (FunDef SinkLore)
    onFun _ fd =
      let vtable = ST.insertFParams (funDefParams fd) mempty
          (body, _) = optimiseBody vtable mempty $ funDefBody fd
       in pure fd {funDefBody = body}

    -- Every name bound by the constants counts as used afterwards, since
    -- the constants stay in scope for the rest of the program.
    onConsts :: Stms SinkLore -> PassM (Stms SinkLore)
    onConsts consts =
      pure $
        fst $
          optimiseStms mempty mempty consts $
            namesFromList $ M.keys $ scopeOf consts