-- | A dependency-graph representation of a body of SOACS statements.
--
-- Nodes ('NodeT') are the body's statements (later classified as SOACs,
-- loops or matches), the names free in the body result ('ResNode'), and
-- the names free in the body ('FreeNode').  Edges ('EdgeT') point from a
-- node to the node producing a name it depends on, consumes, or aliases.
-- NOTE(review): the module lives under Futhark.Optimise.Fusion, so it is
-- presumably consumed by the fusion pass; confirm against its importers.
module Futhark.Optimise.Fusion.GraphRep
(
EdgeT (..),
NodeT (..),
DepContext,
DepGraphAug,
DepGraph (..),
DepNode,
getName,
nodeFromLNode,
mergedContext,
mapAcross,
edgesBetween,
reachable,
applyAugs,
depsFromEdge,
contractEdge,
isRealNode,
isCons,
isDep,
isInf,
mkDepGraph,
mkDepGraphForFun,
pprg,
)
where
import Control.Monad.Reader
import Data.Bifunctor (bimap)
import Data.Foldable (foldlM)
import Data.Graph.Inductive.Dot qualified as G
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.Graph.Inductive.Tree qualified as G
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Util (nubOrd)
-- | The label of an edge in the dependency graph.  Edges point from the
-- node that needs a name to the node that produces it, and every label
-- carries that name (see 'getName').
data EdgeT
  = -- | The source node's expression result aliases the named value.
    Alias VName
  | -- | The source node uses the name other than as a SOAC input.
    InfDep VName
  | -- | The source node uses the name as a SOAC input.
    Dep VName
  | -- | The source node consumes the named value.
    Cons VName
  | -- | Extra edge added by 'addExtraCons' for consumed/aliased values.
    Fake VName
  | -- | From a 'ResNode' to the producer of the result name.
    Res VName
  deriving (Eq, Ord)
-- | A node in the dependency graph.
data NodeT
  = -- | A statement that has not been classified as anything more specific.
    StmNode (Stm SOACS)
  | -- | A statement whose expression is a SOAC.  'nodeToSoacNode' creates
    -- these with empty ('mempty') array transforms.
    SoacNode H.ArrayTransforms (Pat Type) (H.SOAC SOACS) (StmAux (ExpDec SOACS))
  | -- | A name that is free in the body result.
    ResNode VName
  | -- | A name that is free in the body.
    FreeNode VName
  | -- | A node wrapped between two statement sequences.
    -- NOTE(review): not constructed in this module; semantics defined by users.
    FinalNode (Stms SOACS) NodeT (Stms SOACS)
  | -- | A 'Match' statement; 'nodeToSoacNode' starts it with an empty list.
    MatchNode (Stm SOACS) [(NodeT, [EdgeT])]
  | -- | A loop statement; 'nodeToSoacNode' starts it with an empty list.
    DoNode (Stm SOACS) [(NodeT, [EdgeT])]
  deriving (Eq)
-- | Debug rendering: only 'Dep' and 'InfDep' edges include their name.
instance Show EdgeT where
  show edge =
    case edge of
      Dep vName -> "Dep " <> prettyString vName
      InfDep vName -> "iDep " <> prettyString vName
      Cons _ -> "Cons"
      Fake _ -> "Fake"
      Res _ -> "Res"
      Alias _ -> "Alias"
-- | Debug rendering of a node, mostly as the names it binds.
instance Show NodeT where
  show node =
    case node of
      StmNode (Let pat _ _) -> commaNames (patNames pat)
      SoacNode _ pat _ _ -> prettyString pat
      FinalNode _ inner _ -> show inner
      ResNode name -> prettyString $ "Res: " ++ prettyString name
      FreeNode name -> prettyString $ "Input: " ++ prettyString name
      MatchNode stm _ -> "Match: " ++ commaNames (stmNames stm)
      DoNode stm _ -> "Do: " ++ commaNames (stmNames stm)
    where
      -- Pretty-print each name and join them with ", ".
      commaNames :: [VName] -> String
      commaNames = L.intercalate ", " . map prettyString
-- | The name carried by an edge label, whatever its kind.
getName :: EdgeT -> VName
getName (Alias vn) = vn
getName (InfDep vn) = vn
getName (Dep vn) = vn
getName (Cons vn) = vn
getName (Fake vn) = vn
getName (Res vn) = vn
-- | 'False' for the bookkeeping nodes ('ResNode', 'FreeNode'), 'True'
-- for every node that stands for a statement.
isRealNode :: NodeT -> Bool
isRealNode node =
  case node of
    ResNode {} -> False
    FreeNode {} -> False
    _ -> True
-- | Render the graph in Graphviz dot syntax, using 'show' for both node
-- and edge labels.
pprg :: DepGraph -> String
pprg dg = G.showDot (G.fglToDotString labelled)
  where
    labelled = G.nemap show show (dgGraph dg)
-- | A node identifier paired with its label.
type DepNode = G.LNode NodeT
-- | A (source, target, label) edge of the dependency graph.
type DepEdge = G.LEdge EdgeT
-- | The context (incoming adjacency, node, label, outgoing adjacency) of
-- a node in the dependency graph.
type DepContext = G.Context NodeT EdgeT
-- | A dependency graph together with the lookup tables used while
-- building its edges.
data DepGraph = DepGraph
  { -- | The graph itself.
    dgGraph :: G.Gr NodeT EdgeT,
    -- | Which node produces each name (filled in by 'makeMapping').
    dgProducerMapping :: ProducerMapping,
    -- | Alias information for the body statements (see 'makeAliasTable').
    dgAliasTable :: AliasTable
  }
-- | A monadic transformation of a dependency graph; these are chained
-- together by 'applyAugs'.
type DepGraphAug m = DepGraph -> m DepGraph
-- | For a node, the (name, label) pairs for which 'genEdges' should add
-- an edge from the node to the name's producer.
type EdgeGenerator = NodeT -> [(VName, EdgeT)]
-- | Maps a name to the graph node that produces it.
type ProducerMapping = M.Map VName G.Node
-- | Fill in 'dgProducerMapping' from the current node labels: every name
-- returned by 'getOutputs' for a node maps to that node.
makeMapping :: (Monad m) => DepGraphAug m
makeMapping dg =
  pure dg {dgProducerMapping = M.fromList producers}
  where
    producers =
      [ (name, i)
      | (i, node) <- G.labNodes (dgGraph dg),
        name <- getOutputs node
      ]
-- | Run alias analysis (starting from an empty table) over the given
-- statements and store the resulting table in 'dgAliasTable'.
makeAliasTable :: (Monad m) => Stms SOACS -> DepGraphAug m
makeAliasTable stms dg =
  pure dg {dgAliasTable = table}
  where
    (_, (table, _)) = Alias.analyseStms mempty stms
-- | Apply the augmentations left to right, feeding each one the graph
-- produced by the previous one.
applyAugs :: (Monad m) => [DepGraphAug m] -> DepGraphAug m
applyAugs [] g = pure g
applyAugs (aug : rest) g = aug g >>= applyAugs rest
-- | For each of the given nodes, ask @edge_fun@ for (name, label) pairs
-- and add an edge from the node to the producer of each name.  Names
-- that have no entry in 'dgProducerMapping' are skipped.
genEdges :: (Monad m) => [DepNode] -> EdgeGenerator -> DepGraphAug m
genEdges l_stms edge_fun dg =
  depGraphInsertEdges new_edges dg
  where
    producers = dgProducerMapping dg
    new_edges =
      [ G.toLEdge (from, to) edgeT
      | (from, node) <- l_stms,
        (dep, edgeT) <- edge_fun node,
        -- A failed pattern match here just filters the element out.
        Just to <- [M.lookup dep producers]
      ]
-- | Add the given labelled edges to the graph.
depGraphInsertEdges :: (Monad m) => [DepEdge] -> DepGraphAug m
depGraphInsertEdges edgs dg =
  let g' = G.insEdges edgs (dgGraph dg)
   in pure dg {dgGraph = g'}
-- | Apply a monadic transformation to the context of every node.  Each
-- node of the original graph is matched out, its context transformed,
-- and the result reinserted with 'G.&'.  Nodes that can no longer be
-- matched are skipped.
mapAcross :: (Monad m) => (DepContext -> m DepContext) -> DepGraphAug m
mapAcross f dg = do
  let g = dgGraph dg
  g' <- foldlM step g (G.nodes g)
  pure dg {dgGraph = g'}
  where
    step acc n =
      case G.match n acc of
        (Nothing, _) -> pure acc
        (Just c, rest) -> do
          c' <- f c
          pure (c' G.& rest)
-- | The statement of a 'StmNode' as a one-element sequence; every other
-- kind of node yields no statements.
stmFromNode :: NodeT -> Stms SOACS
stmFromNode node =
  case node of
    StmNode stm -> oneStm stm
    _ -> mempty
-- | The node identifier of a labelled node.
nodeFromLNode :: DepNode -> G.Node
nodeFromLNode (n, _) = n
-- | The name carried by a labelled edge.
depsFromEdge :: DepEdge -> VName
depsFromEdge (_, _, edgeT) = getName edgeT
-- | All labelled edges of the subgraph induced by the two nodes: edges in
-- either direction between them, plus any self-loops on either node.
edgesBetween :: DepGraph -> G.Node -> G.Node -> [DepEdge]
edgesBetween dg n1 n2 =
  let sub = G.subgraph [n1, n2] (dgGraph dg)
   in G.labEdges sub
-- | Whether @target@ is among the nodes that 'Q.reachable' reports for
-- @source@.
reachable :: DepGraph -> G.Node -> G.Node -> Bool
reachable dg source target =
  any (== target) (Q.reachable source (dgGraph dg))
-- | Run an edge generator over every node currently in the graph.
augWithFun :: (Monad m) => EdgeGenerator -> DepGraphAug m
augWithFun f dg = genEdges allNodes f dg
  where
    allNodes = G.labNodes (dgGraph dg)
-- | Add an edge from each 'StmNode' to the producer of every name its
-- statement uses (see 'stmInputs').  Names classified 'SOACInput' get a
-- 'Dep' edge; names classified 'Other' get an 'InfDep' edge.
addDeps :: (Monad m) => DepGraphAug m
addDeps = augWithFun toDep
  where
    toDep :: EdgeGenerator
    toDep stmt =
      let classified = S.toList $ foldMap stmInputs (stmFromNode stmt)
          fusible = [v | (v, SOACInput) <- classified]
          infusible = [v | (v, Other) <- classified]
       in [(v, Dep v) | v <- fusible] <> [(v, InfDep v) | v <- infusible]
-- | For every 'StmNode', add a 'Cons' edge to the producer of each name
-- its expression consumes, and an 'Alias' edge to the producer of each
-- name its result aliases.  Alias analysis runs on the expression alone,
-- starting from an empty alias table.
addConsAndAliases :: (Monad m) => DepGraphAug m
addConsAndAliases = augWithFun edges
  where
    edges :: EdgeGenerator
    edges (StmNode s) =
      let e = Alias.analyseExp mempty (stmExp s)
       in consEdges e <> aliasEdges e
    edges _ = mempty

    consEdges e =
      map (\vname -> (vname, Cons vname)) (namesToList (consumedInExp e))

    aliasEdges e =
      [(vname, Alias vname) | vname <- namesToList (mconcat (expAliases e))]
-- | Add 'Fake' edges around every consumption.
--
-- Edges point from a user to the producer of a name (see 'genEdges').
-- For every @Cons cname@ edge @from -> to@, this collects the
-- predecessors of @to@ and of the producers of @cname@'s aliases
-- (according to 'dgAliasTable').  For each predecessor other than @from@
-- whose edge is labelled with @cname@ or one of its aliases, it adds a
-- @Fake cname@ edge from @from@ to that predecessor.
-- NOTE(review): presumably this orders the consumer after the other users
-- of the consumed value; confirm against the fusion pass.
addExtraCons :: (Monad m) => DepGraphAug m
addExtraCons dg =
  depGraphInsertEdges (concatMap makeEdge (G.labEdges g)) dg
  where
    g = dgGraph dg
    alias_table = dgAliasTable dg
    mapping = dgProducerMapping dg

    makeEdge (from, to, Cons cname) = do
      let aliases = namesToList $ M.findWithDefault mempty cname alias_table
          -- Aliases that have no producer node are skipped; previously a
          -- partial 'M.!' lookup crashed on them.
          to' = [n | Just n <- map (`M.lookup` mapping) aliases]
          p (tonode, toedge) =
            tonode /= from && getName toedge `elem` (cname : aliases)
      (to2, _) <- filter p $ concatMap (G.lpre g) to' <> G.lpre g to
      pure $ G.toLEdge (from, to2) (Fake cname)
    makeEdge _ = []
-- | Apply a monadic transformation to the label of every node.  The
-- adjacency of each node is passed through unchanged.
mapAcrossNodeTs :: (Monad m) => (NodeT -> m NodeT) -> DepGraphAug m
mapAcrossNodeTs f = mapAcross relabel
  where
    relabel (ins, n, nodeT, outs) =
      (\nodeT' -> (ins, n, nodeT', outs)) <$> f nodeT
-- | Classify a 'StmNode' by its expression:
--
-- * an 'Op' that 'H.fromExp' recognises becomes a 'SoacNode' with empty
--   array transforms;
-- * a 'DoLoop' becomes a 'DoNode' and a 'Match' a 'MatchNode', both
--   with an empty list.
--
-- Every other node, including an 'Op' that is not a SOAC, is returned
-- unchanged.
nodeToSoacNode :: (HasScope SOACS m, Monad m) => NodeT -> m NodeT
nodeToSoacNode n@(StmNode s@(Let pat aux e)) =
  case e of
    Op {} -> do
      classified <- H.fromExp e
      pure $ either (const n) (\hsoac -> SoacNode mempty pat hsoac aux) classified
    DoLoop {} -> pure (DoNode s [])
    Match {} -> pure (MatchNode s [])
    _ -> pure n
nodeToSoacNode n = pure n
-- | Reclassify every node of the graph with 'nodeToSoacNode'.
convertGraph :: (HasScope SOACS m, Monad m) => DepGraphAug m
convertGraph dg = mapAcrossNodeTs nodeToSoacNode dg
-- | Add all edge kinds and then classify the nodes.  This expects
-- 'dgProducerMapping' and 'dgAliasTable' to be filled in already.  The
-- order matters:
--
-- * 'addExtraCons' looks at the 'Cons' edges added by 'addConsAndAliases';
-- * the edge generators match on 'StmNode', so 'convertGraph' runs last.
initialGraphConstruction :: (HasScope SOACS m, Monad m) => DepGraphAug m
initialGraphConstruction =
  applyAugs steps
  where
    steps =
      [ addDeps,
        addConsAndAliases,
        addExtraCons,
        addResEdges,
        convertGraph
      ]
-- | The edge-less starting graph for a body, with empty lookup tables.
-- Nodes are numbered consecutively from 0, in this order:
--
-- * one 'StmNode' per body statement;
-- * one 'ResNode' per name free in the body result;
-- * one 'FreeNode' per name free in the body.
emptyGraph :: Body SOACS -> DepGraph
emptyGraph body =
  DepGraph
    { dgGraph = G.mkGraph (zip [0 ..] allNodes) [],
      dgProducerMapping = mempty,
      dgAliasTable = mempty
    }
  where
    allNodes =
      map StmNode (stmsToList (bodyStms body))
        <> map ResNode (namesToList (freeIn (bodyResult body)))
        <> map FreeNode (namesToList (freeIn body))
-- | Build the dependency graph of a body:
--
-- 1. start from 'emptyGraph';
-- 2. compute the producer mapping and the alias table;
-- 3. add the edges and classify the nodes ('initialGraphConstruction').
mkDepGraph :: (HasScope SOACS m, Monad m) => Body SOACS -> m DepGraph
mkDepGraph body = do
  withProducers <- makeMapping (emptyGraph body)
  withAliases <- makeAliasTable (bodyStms body) withProducers
  initialGraphConstruction withAliases
-- | Build the dependency graph of a function body.  The function's
-- parameters and the body's statements are put in scope for
-- 'nodeToSoacNode'.
mkDepGraphForFun :: FunDef SOACS -> DepGraph
mkDepGraphForFun f = runReader (mkDepGraph body) scope
  where
    body = funDefBody f
    scope :: Scope SOACS
    scope = scopeOfFParams (funDefParams f) <> scopeOf (bodyStms body)
-- | Merge two node contexts into one context for the first context's node
-- (@n1@), labelled @mergedlabel@.  The two adjacency lists on each side
-- are concatenated and deduplicated.  Entries pointing at either @n1@ or
-- @n2@ are dropped, so the merged node has no self-loops.
mergedContext :: (Ord b) => a -> G.Context a b -> G.Context a b -> G.Context a b
mergedContext mergedlabel (inp1, n1, _, out1) (inp2, n2, _, out2) =
  (merge inp1 inp2, n1, mergedlabel, merge out1 out2)
  where
    merge xs ys =
      [adj | adj@(_, n) <- nubOrd (xs <> ys), n /= n1, n /= n2]
-- | Delete node @n2@ and the node of @ctx@, then insert @ctx@ in their
-- place.
-- NOTE(review): presumably @ctx@ comes from 'mergedContext', so it refers
-- to neither deleted node; verify at call sites.
contractEdge :: (Monad m) => G.Node -> DepContext -> DepGraphAug m
contractEdge n2 ctx dg =
  let n1 = G.node' ctx
      remaining = G.delNodes [n1, n2] (dgGraph dg)
   in pure dg {dgGraph = ctx G.& remaining}
-- | Add a 'Res' edge from every 'ResNode' to the producer of its name.
addResEdges :: (Monad m) => DepGraphAug m
addResEdges dg = augWithFun getStmRes dg
-- | How a statement uses a name.
data Classification
  = -- | As an input array of a SOAC, or as the array of an expression
    -- that 'H.transformFromExp' recognises.
    SOACInput
  | -- | In any other way.
    Other
  deriving (Eq, Ord, Show)
-- | A set of names, each paired with how it is used.  A name may appear
-- with both classifications.
type Classifications = S.Set (VName, Classification)
-- | Every name free in the argument, classified as 'Other'.
freeClassifications :: (FreeIn a) => a -> Classifications
freeClassifications x =
  S.fromList [(v, Other) | v <- namesToList (freeIn x)]
-- | The classified names used by a statement.  Names free in the pattern
-- and the auxiliary information are 'Other'; the expression is
-- classified by 'expInputs'.
stmInputs :: Stm SOACS -> Classifications
stmInputs (Let pat aux e) =
  let fromBinding = freeClassifications (pat, aux)
   in fromBinding <> expInputs e
-- | The classified names used by a body: those of each statement, plus
-- the names free in the result (as 'Other').
bodyInputs :: Body SOACS -> Classifications
bodyInputs body =
  foldMap stmInputs (bodyStms body) <> freeClassifications (bodyResult body)
-- | The classified names used by an expression.
--
-- * 'Match' and 'DoLoop': nested bodies are classified recursively with
--   'bodyInputs'; the remaining free names are 'Other'.
-- * 'Screma', 'Hist', 'Scatter', 'Stream': the input arrays are
--   'SOACInput' and everything else free in the operator is 'Other'.
-- * 'JVP' and 'VJP': everything free is 'Other'.
-- * Otherwise, if 'H.transformFromExp' recognises the expression, its
--   array is 'SOACInput' and all other free names are 'Other'; anything
--   else is entirely 'Other'.
expInputs :: Exp SOACS -> Classifications
expInputs (Match cond cases defbody attr) =
  foldMap (bodyInputs . caseBody) cases
    <> bodyInputs defbody
    <> freeClassifications (cond, attr)
expInputs (DoLoop params form b1) =
  freeClassifications (params, form) <> bodyInputs b1
expInputs (Op soac) =
  case soac of
    Futhark.Screma w is form -> arrays is <> freeClassifications (w, form)
    Futhark.Hist w is ops lam -> arrays is <> freeClassifications (w, ops, lam)
    Futhark.Scatter w is lam iws -> arrays is <> freeClassifications (w, lam, iws)
    Futhark.Stream w is nes lam -> arrays is <> freeClassifications (w, nes, lam)
    Futhark.JVP {} -> freeClassifications soac
    Futhark.VJP {} -> freeClassifications soac
  where
    arrays :: [VName] -> Classifications
    arrays = S.fromList . map (\v -> (v, SOACInput))
expInputs e =
  case H.transformFromExp mempty e of
    Just (arr, _) ->
      S.singleton (arr, SOACInput)
        <> freeClassifications (freeIn e `namesSubtract` oneName arr)
    Nothing -> freeClassifications e
-- | The names bound by a statement's pattern.
stmNames :: Stm SOACS -> [VName]
stmNames stm = patNames (stmPat stm)
-- | Edge generator for result nodes: a 'ResNode' asks for a 'Res' edge
-- to the producer of its name; every other node asks for nothing.
getStmRes :: EdgeGenerator
getStmRes node =
  case node of
    ResNode name -> [(name, Res name)]
    _ -> []
-- | The names a node makes available to other nodes: the pattern names
-- of statement-like nodes, or the single name of a 'FreeNode'.  A
-- 'ResNode' produces nothing.  A 'FinalNode' is not expected here and
-- raises an error.
getOutputs :: NodeT -> [VName]
getOutputs (StmNode stm) = stmNames stm
getOutputs (ResNode _) = []
getOutputs (FreeNode name) = [name]
getOutputs (MatchNode stm _) = stmNames stm
getOutputs (DoNode stm _) = stmNames stm
getOutputs FinalNode {} = error "Final nodes cannot generate edges"
getOutputs (SoacNode _ pat _ _) = patNames pat
-- | 'True' for 'Dep' and 'Res' edges.
isDep :: EdgeT -> Bool
isDep edge =
  case edge of
    Dep _ -> True
    Res _ -> True
    _ -> False
-- | 'True' for edges labelled 'InfDep' or 'Fake'.
isInf :: (G.Node, G.Node, EdgeT) -> Bool
isInf (_, _, InfDep _) = True
isInf (_, _, Fake _) = True
isInf _ = False
-- | 'True' only for 'Cons' edges.
isCons :: EdgeT -> Bool
isCons edge =
  case edge of
    Cons _ -> True
    _ -> False