{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Sink (sinkGPU, sinkMC) where
import Control.Monad.State
import Data.Bifunctor
import Data.List (foldl')
import Data.Map qualified as M
import Data.Sequence ((<|))
import Data.Sequence qualified as SQ
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Builder.Class
import Futhark.Construct (sliceDim)
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.Pass
-- | The symbol table consulted while sinking.
type SymbolTable rep = ST.SymbolTable rep

-- | Statements that are candidates for being sunk, keyed by the name
-- they bind.
type Sinking rep = M.Map VName (Stm rep)

-- | Names bound by statements that have been sunk somewhere deeper in
-- the program.
type Sunk = Names

-- | A sinker rewrites some program fragment given the current symbol
-- table and sinking candidates, returning the rewritten fragment
-- together with the names of the candidates it absorbed.
type Sinker rep a = SymbolTable rep -> Sinking rep -> a -> (a, Sunk)

-- | What a representation must provide for sinking to apply.
type Constraints rep =
  ( ASTRep rep,
    Aliased rep,
    Buildable rep,
    ST.IndexOp (Op rep)
  )
-- | Approximate how often each free variable of a statement is used.
--
-- The count is deliberately rough: callers only care whether a name is
-- used at most once or more than once.  For a 'Match', the condition,
-- the default body and every case each contribute one use per free
-- variable.  Variables free in an 'Op' or a 'DoLoop' are counted twice,
-- so they are never considered single-use.
multiplicity :: Constraints rep => Stm rep -> M.Map VName Int
multiplicity stm =
  case stmExp stm of
    Match cond cases defbody _ ->
      -- Strict left fold: the lazy 'foldl' would build a chain of
      -- unevaluated 'M.unionWith' applications.
      foldl' comb mempty $
        free 1 cond
          : free 1 defbody
          : map (free 1 . caseBody) cases
    Op {} -> free 2 stm
    DoLoop {} -> free 2 stm
    _ -> free 1 stm
  where
    -- Map every name free in the argument to the count @k@.
    free k x = M.fromList $ zip (namesToList $ freeIn x) $ repeat k

    comb :: M.Map VName Int -> M.Map VName Int -> M.Map VName Int
    comb = M.unionWith (+)
-- | Optimise one branch of a 'Match'.
--
-- Every sinking candidate that is free in this branch's statements or
-- result, and whose own free variables are all available in the symbol
-- table, is placed at the front of the branch's statements.  The
-- remaining candidates are handed on for sinking further into the
-- branch.  The returned names cover both the statements placed here
-- and whatever was sunk deeper inside.
optimiseBranch ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (Body rep)
optimiseBranch onOp vtable sinking (Body dec stms res) =
  (Body dec new_stms res, placed_names <> inner_sunk)
  where
    used_here = freeIn stms <> freeIn res

    canPlaceHere v candidate =
      nameIn v used_here
        && all (`ST.available` vtable) (namesToList $ freeIn candidate)

    (place_here, pass_on) = M.partitionWithKey canPlaceHere sinking

    placed = stmsFromList $ M.elems place_here

    placed_names = namesFromList $ foldMap (patNames . stmPat) placed

    (new_stms, inner_sunk) =
      optimiseStms onOp vtable pass_on (placed <> stms) (freeIn res)
-- | Optimise the components of a 'DoLoop': merge parameters, loop form
-- and body.
--
-- For a 'WhileLoop' the body is simply optimised.  For a 'ForLoop',
-- every loop variable is first replaced by an explicit 'Index'
-- statement at the start of the body, turning it into a sinking
-- candidate.  Loop variables whose index statement was sunk are removed
-- from the loop form.  The index statements that were not sunk are
-- dropped from the front of the optimised body, and their loop
-- variables are kept.  This relies on the surviving index statements
-- still being the first statements of the optimised body.
optimiseLoop ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep ([(FParam rep, SubExp)], LoopForm rep, Body rep)
optimiseLoop onOp vtable sinking (merge, form, body0) =
  case form of
    WhileLoop {} ->
      let (body1, sunk) = optimiseBody onOp loop_vtable sinking body0
       in ((merge, form, body1), sunk)
    ForLoop i it bound loop_vars ->
      let with_indexes =
            body0 {bodyStms = foldr (indexStm i) (bodyStms body0) loop_vars}
          (body1, sunk) = optimiseBody onOp loop_vtable sinking with_indexes
          kept_vars = filter ((`notNameIn` sunk) . paramName . fst) loop_vars
          body2 =
            body1 {bodyStms = SQ.drop (length kept_vars) (bodyStms body1)}
       in ((merge, ForLoop i it bound kept_vars, body2), sunk)
  where
    params = map fst merge

    loop_scope = case form of
      WhileLoop {} -> scopeOfFParams params
      ForLoop i it _ _ -> M.insert i (IndexName it) $ scopeOfFParams params

    loop_vtable = ST.fromScope loop_scope <> vtable

    -- Prepend the statement @x = arr[i, ...]@ (keeping every dimension
    -- of @x@'s type) to the given statements.
    indexStm i (x, arr) stms =
      let x_t = typeOf x
          slice = Slice $ DimFix (Var i) : map sliceDim (arrayDims x_t)
          e = BasicOp $ Index arr slice
          pat = mkExpPat [Ident (paramName x) x_t] e
       in Let pat (StmAux mempty mempty (mkExpDec pat e)) e <| stms
-- | Optimise a sequence of statements.
--
-- The statements are walked in order while a symbol table and the
-- current sinking candidates are threaded along.  A statement becomes a
-- candidate when it is an 'Index' binding a single value of primitive
-- type whose name is used at most once.  Uses are counted with
-- 'multiplicity', and each name in @free_in_res@ counts as one use.  If
-- the remaining statements report a candidate as sunk, it is removed
-- from this sequence.  Branches, loops, operations and any other nested
-- bodies are optimised recursively with the candidates known at that
-- point.
--
-- Returns the new statements and the names of all statements sunk.
optimiseStms ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  SymbolTable rep ->
  Sinking rep ->
  Stms rep ->
  Names ->
  (Stms rep, Sunk)
optimiseStms onOp init_vtable init_sinking all_stms free_in_res =
  first stmsFromList $ go init_vtable init_sinking $ stmsToList all_stms
  where
    -- Approximate use counts for every name in these statements and
    -- the result.
    uses =
      M.unionsWith (+) $
        M.fromList [(v, 1) | v <- namesToList free_in_res]
          : map multiplicity (stmsToList all_stms)

    go _ _ [] = ([], mempty)
    go vtable sinking (stm : stms) =
      case stmExp stm of
        BasicOp Index {}
          | [pe] <- patElems (stmPat stm),
            primType (patElemType pe),
            maybe True (== 1) (M.lookup (patElemName pe) uses) ->
              let v = patElemName pe
                  (stms', sunk) = go vtable' (M.insert v stm sinking) stms
               in if v `nameIn` sunk
                    then (stms', sunk)
                    else (stm : stms', sunk)
        Match cond cases defbody ret ->
          let branch = optimiseBranch onOp vtable sinking
              onCase (Case vs body) = first (Case vs) $ branch body
              (cases', cases_sunk) = unzip $ map onCase cases
              (defbody', defbody_sunk) = branch defbody
              (stms', sunk) = rest
           in ( stm {stmExp = Match cond cases' defbody' ret} : stms',
                mconcat cases_sunk <> defbody_sunk <> sunk
              )
        DoLoop merge lform body ->
          let ((merge', lform', body'), loop_sunk) =
                optimiseLoop onOp vtable sinking (merge, lform, body)
              (stms', stms_sunk) = rest
           in ( stm {stmExp = DoLoop merge' lform' body'} : stms',
                stms_sunk <> loop_sunk
              )
        Op op ->
          let (op', op_sunk) = onOp vtable sinking op
              (stms', stms_sunk) = rest
           in (stm {stmExp = Op op'} : stms', stms_sunk <> op_sunk)
        _ ->
          let (e', stm_sunk) = runState (mapExpM mapper (stmExp stm)) mempty
              (stms', stms_sunk) = rest
           in (stm {stmExp = e'} : stms', stm_sunk <> stms_sunk)
      where
        vtable' = ST.insertStm stm vtable

        -- The remaining statements, with the current candidates unchanged.
        rest = go vtable' sinking stms

        -- Optimise every nested body, accumulating what was sunk.
        mapper =
          identityMapper
            { mapOnBody = \scope body ->
                state $ \acc ->
                  let (body', sunk) =
                        optimiseBody
                          onOp
                          (ST.fromScope scope <> vtable)
                          sinking
                          body
                   in (body', acc <> sunk)
            }
-- | Optimise the statements of a body.  The names free in the body's
-- result are passed to 'optimiseStms' as used names.
optimiseBody ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (Body rep)
optimiseBody onOp vtable sinking (Body attr stms res) =
  first rebuild $ optimiseStms onOp vtable sinking stms $ freeIn res
  where
    rebuild stms' = Body attr stms' res
-- | Optimise the statements of a kernel body.  The names free in its
-- kernel results are passed to 'optimiseStms' as used names.
optimiseKernelBody ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (KernelBody rep)
optimiseKernelBody onOp vtable sinking (KernelBody attr stms res) =
  first (\stms' -> KernelBody attr stms' res) $
    optimiseStms onOp vtable sinking stms (freeIn res)
-- | Optimise the lambdas and kernel bodies inside a 'SegOp'.
--
-- Both are optimised with the symbol table extended by the scope of the
-- operation's segment space.  The names sunk by all of them are
-- collected.
optimiseSegOp ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (SegOp lvl rep)
optimiseSegOp onOp vtable sinking op =
  runState (mapSegOpM mapper op) mempty
  where
    op_vtable = ST.fromScope (scopeOfSegSpace (segSpace op)) <> vtable

    -- Run a pure sinker inside the state monad, adding its sunk names
    -- to the accumulated state.
    record f x = state $ \acc -> let (x', sunk) = f x in (x', acc <> sunk)

    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = \lam ->
            (\body -> lam {lambdaBody = body})
              <$> record (optimiseBody onOp op_vtable sinking) (lambdaBody lam),
          mapOnSegOpBody = record (optimiseKernelBody onOp op_vtable sinking)
        }
-- | The representation the sinking pass works on: the input
-- representation annotated with alias information.
type SinkRep rep = Aliases rep
-- | Build the sinking pass from a sinker for the representation's
-- operations.
--
-- The program is annotated with alias information.  Every function body
-- and the top-level constants are then optimised, and the alias
-- annotations are removed again.
sink ::
  ( Buildable rep,
    CanBeAliased (Op rep),
    ST.IndexOp (OpWithAliases (Op rep))
  ) =>
  Sinker (SinkRep rep) (Op (SinkRep rep)) ->
  Pass rep rep
sink onOp =
  Pass "sink" "move memory loads closer to their uses" $ \prog ->
    removeProgAliases
      <$> intraproceduralTransformationWithConsts
        onConsts
        onFun
        (Alias.aliasAnalysis prog)
  where
    onFun _ fd =
      let params_vtable = ST.insertFParams (funDefParams fd) mempty
          body = fst $ optimiseBody onOp params_vtable mempty $ funDefBody fd
       in pure fd {funDefBody = body}

    -- Every name bound by the constants is passed as a used name.
    -- NOTE(review): presumably because functions may refer to any
    -- constant; confirm.
    onConsts consts =
      pure . fst $
        optimiseStms onOp mempty mempty consts $
          namesFromList (M.keys (scopeOf consts))
-- | Sinking for the GPU representation.  Statements are sunk into
-- 'SegOp' kernels and 'GPUBody' bodies.  Other host operations are left
-- untouched.
sinkGPU :: Pass GPU GPU
sinkGPU = sink onGPUOp
  where
    onGPUOp :: Sinker (SinkRep GPU) (Op (SinkRep GPU))
    onGPUOp vtable sinking host_op =
      case host_op of
        SegOp op ->
          first SegOp $ optimiseSegOp onGPUOp vtable sinking op
        GPUBody types body ->
          first (GPUBody types) $ optimiseBody onGPUOp vtable sinking body
        _ -> (host_op, mempty)
-- | Sinking for the multicore representation.  Statements are sunk into
-- both the optional nested 'SegOp' and the main 'SegOp' of a 'ParOp'.
-- Other operations are left untouched.
sinkMC :: Pass MC MC
sinkMC = sink onMCOp
  where
    onMCOp :: Sinker (SinkRep MC) (Op (SinkRep MC))
    onMCOp vtable sinking mc_op =
      case mc_op of
        ParOp par_op op ->
          let inSegOp = optimiseSegOp onMCOp vtable sinking
              (par_op', par_sunk) = case par_op of
                Nothing -> (Nothing, mempty)
                Just nested -> first Just $ inSegOp nested
              (op', sunk) = inSegOp op
           in (ParOp par_op' op', par_sunk <> sunk)
        _ -> (mc_op, mempty)