-- | A graph representation of a sequence of Futhark statements
-- (i.e. a 'Body'), built to handle fusion.  Could perhaps be made
-- more general.  An important property is that it does not handle
-- "nested bodies" (e.g. 'Match'); these are represented as single
-- nodes.
--
-- This is all implemented on top of the graph representation provided
-- by the @fgl@ package ("Data.Graph.Inductive").  The graph provided
-- by this package allows nodes and edges to have arbitrarily-typed
-- "labels".  It is these labels ('EdgeT', 'NodeT') that we use to
-- contain Futhark-specific information.  An edge goes *from* uses of
-- variables to the node that produces that variable.  There are also
-- edges that do not represent normal data dependencies, but other
-- things.  This means that a node can have multiple edges for the
-- same name, indicating different kinds of dependencies.
module Futhark.Optimise.Fusion.GraphRep
  ( -- * Data structure
    EdgeT (..),
    NodeT (..),
    DepContext,
    DepGraphAug,
    DepGraph (..),
    DepNode,

    -- * Queries
    getName,
    nodeFromLNode,
    mergedContext,
    mapAcross,
    edgesBetween,
    reachable,
    applyAugs,
    depsFromEdge,
    contractEdge,
    isRealNode,
    isCons,
    isDep,
    isInf,

    -- * Construction
    mkDepGraph,
    mkDepGraphForFun,
    pprg,
  )
where

import Control.Monad.Reader
import Data.Bifunctor (bimap)
import Data.Foldable (foldlM)
import Data.Graph.Inductive.Dot qualified as G
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.Graph.Inductive.Tree qualified as G
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Data.Set qualified as S
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Util (nubOrd)

-- | Information associated with an edge in the graph.  Every
-- constructor carries the name of the variable the edge was created
-- for; the constructor itself records what kind of dependency it is.
-- Edges go from the node that uses the variable to the node that
-- produces it.
data EdgeT
  = -- | The pattern of the source node aliases this variable.
    Alias VName
  | -- | The variable is used in some way other than as an array input
    -- to a SOAC (an "infusible" dependency).
    InfDep VName
  | -- | The variable is used as an array input to a SOAC (a
    -- potentially fusible dependency).
    Dep VName
  | -- | The variable is consumed by the source node.
    Cons VName
  | -- | An ordering-only edge inserted because of consumption of the
    -- variable (or one of its aliases).
    Fake VName
  | -- | The variable is part of the result of the body.
    Res VName
  deriving (Eq, Ord)

-- | Information associated with a node in the graph.
data NodeT
  = -- | A statement that has not (yet) been given a more specific
    -- representation.
    StmNode (Stm SOACS)
  | -- | A statement whose expression is a SOAC.  The fields are the
    -- array transforms associated with the node (empty when first
    -- created from a statement), the pattern, the SOAC itself, and the
    -- statement's auxiliary information.
    SoacNode H.ArrayTransforms (Pat Type) (H.SOAC SOACS) (StmAux (ExpDec SOACS))
  | -- | First 'VName' is result; last is input.
    TransNode VName H.ArrayTransform VName
  | -- | Node corresponding to a result of the entire computation
    -- (i.e. the 'Result' of a body).  Any node that is not
    -- transitively reachable from one of these can be considered
    -- dead.
    ResNode VName
  | -- | Node corresponding to a free variable.  These are used to
    -- safely handle consumption, which also means we don't have to
    -- create a node for every free single variable.
    FreeNode VName
  | -- | A 'Match' statement.  The list is empty when the node is first
    -- created; NOTE(review): presumably it collects nodes fused into
    -- the branches, with their edges -- confirm against Fusion.
    MatchNode (Stm SOACS) [(NodeT, [EdgeT])]
  | -- | A loop statement.  The list is empty when the node is first
    -- created; NOTE(review): presumably it collects nodes fused into
    -- the loop body, with their edges -- confirm against Fusion.
    DoNode (Stm SOACS) [(NodeT, [EdgeT])]
  deriving (Eq)

-- Debug rendering of edges (used by 'pprg').  Only 'Dep' and 'InfDep'
-- edges include the variable name; the others show just their kind.
instance Show EdgeT where
  show edge = case edge of
    Dep vName -> "Dep " <> prettyString vName
    InfDep vName -> "iDep " <> prettyString vName
    Cons _ -> "Cons"
    Fake _ -> "Fake"
    Res _ -> "Res"
    Alias _ -> "Alias"

-- Debug rendering of nodes (used by 'pprg').  Statement-like nodes
-- are shown by the names they bind.
instance Show NodeT where
  show node = case node of
    StmNode (Let pat _ _) -> commaSep $ patNames pat
    SoacNode _ pat _ _ -> prettyString pat
    TransNode _ tr _ -> prettyString (show tr)
    ResNode name -> prettyString $ "Res: " ++ prettyString name
    FreeNode name -> prettyString $ "Input: " ++ prettyString name
    MatchNode stm _ -> "Match: " ++ commaSep (stmNames stm)
    DoNode stm _ -> "Do: " ++ commaSep (stmNames stm)
    where
      -- Prettyprint each name and join them with ", ".
      commaSep names = L.intercalate ", " $ map prettyString names

-- | The name that this edge depends on.  Every 'EdgeT' constructor
-- carries exactly one 'VName'; this extracts it.
getName :: EdgeT -> VName
getName (Alias vn) = vn
getName (InfDep vn) = vn
getName (Dep vn) = vn
getName (Cons vn) = vn
getName (Fake vn) = vn
getName (Res vn) = vn

-- | Does the node actually represent something in the program?  The
-- bookkeeping nodes -- 'ResNode' (a result of the body) and 'FreeNode'
-- (a consumed variable from outside the statements) -- are not real;
-- every other kind of node is.
isRealNode :: NodeT -> Bool
isRealNode ResNode {} = False
isRealNode FreeNode {} = False
isRealNode _ = True

-- | Prettyprint dependency graph as a Graphviz dot document, using the
-- 'Show' instances of 'NodeT' and 'EdgeT' for the labels.
pprg :: DepGraph -> String
pprg = G.showDot . G.fglToDotString . G.nemap show show . dgGraph

-- | A labelled node: a pair of the @fgl@ 'G.Node' identifier and its
-- 'NodeT' label.
type DepNode = G.LNode NodeT

-- | A labelled edge @(from, to, label)@.  In this graph @from@ is the
-- node that uses a variable and @to@ is the node that produces it.
type DepEdge = G.LEdge EdgeT

-- | A tuple with four parts: inbound links to the node, the node
-- itself, the 'NodeT' "label", and outbound links from the node.
-- This is the unit that 'mapAcross' hands to its function, and that
-- 'mergedContext' and 'contractEdge' work with.
type DepContext = G.Context NodeT EdgeT

-- | A dependency graph.  Edges go from *users* of a variable to its
-- *producer* (i.e. from usage to definition).  That means the incoming
-- edges of a node are the dependents of that node, and the outgoing
-- edges are the dependencies of that node.
data DepGraph = DepGraph
  { -- | The underlying @fgl@ graph.
    dgGraph :: G.Gr NodeT EdgeT,
    -- | Which node produces each variable.  Empty in a freshly
    -- constructed graph and filled in by the @makeMapping@ augmentation.
    dgProducerMapping :: ProducerMapping,
    -- | A table mapping VNames to VNames that are aliased to it.
    dgAliasTable :: AliasTable
  }

-- | A "graph augmentation" is a monadic action that modifies the graph:
-- it takes a graph and returns an updated one.  Augmentations are
-- chained left to right by 'applyAugs'.
type DepGraphAug m = DepGraph -> m DepGraph

-- | For each node, the variables it depends on, each paired with the
-- label of the edge to create.  Edge generation turns each pair into
-- an edge from the node to the producer of that variable; names with
-- no producer in the 'ProducerMapping' yield no edge.
type EdgeGenerator = NodeT -> [(VName, EdgeT)]

-- | A mapping from variable name to the graph node that produces
-- it (i.e. whose outputs include that name).
type ProducerMapping = M.Map VName G.Node

-- | Fill in 'dgProducerMapping' by mapping every output name of every
-- node (as given by 'getOutputs') to that node.  If several nodes
-- output the same name, the binding that comes last in 'G.labNodes'
-- is kept, because that is how 'M.fromList' resolves duplicate keys.
makeMapping :: Monad m => DepGraphAug m
makeMapping dg@(DepGraph {dgGraph = g}) =
  pure dg {dgProducerMapping = M.fromList $ concatMap gen_dep_list (G.labNodes g)}
  where
    gen_dep_list :: DepNode -> [(VName, G.Node)]
    gen_dep_list (i, node) = [(name, i) | name <- getOutputs node]

-- | Apply several graph augmentations in sequence, left to right,
-- feeding the graph produced by each one into the next.
applyAugs :: Monad m => [DepGraphAug m] -> DepGraphAug m
applyAugs augs g = foldlM (\acc aug -> aug acc) g augs

-- | Creates deps for the given nodes on the graph using the
-- 'EdgeGenerator'.  Every @(name, label)@ pair generated for a node
-- becomes an edge from that node to the producer of @name@ (looked up
-- in 'dgProducerMapping').  Names that no node produces are skipped.
genEdges :: Monad m => [DepNode] -> EdgeGenerator -> DepGraphAug m
genEdges l_stms edge_fun dg =
  depGraphInsertEdges (concatMap (genEdge (dgProducerMapping dg)) l_stms) dg
  where
    -- Edges from one node to the producers of the names it depends on.
    genEdge :: M.Map VName G.Node -> DepNode -> [G.LEdge EdgeT]
    genEdge name_map (from, node) =
      [ G.toLEdge (from, to) edgeT
      | (dep, edgeT) <- edge_fun node,
        Just to <- [M.lookup dep name_map]
      ]

-- | Add the given labelled edges to the graph.  The producer mapping
-- and alias table are left untouched.
depGraphInsertEdges :: Monad m => [DepEdge] -> DepGraphAug m
depGraphInsertEdges edgs dg =
  pure $ dg {dgGraph = G.insEdges edgs $ dgGraph dg}

-- | Monadically modify every node of the graph.  For each node of the
-- original graph, its context is removed with 'G.match', passed to the
-- function, and the returned context is inserted again with 'G.&'.
-- Nodes that can no longer be matched are skipped.
mapAcross :: Monad m => (DepContext -> m DepContext) -> DepGraphAug m
mapAcross f dg = do
  g' <- foldlM (flip helper) (dgGraph dg) (G.nodes (dgGraph dg))
  pure $ dg {dgGraph = g'}
  where
    helper n g' = case G.match n g' of
      (Just c, g_new) -> do
        c' <- f c
        pure $ c' G.& g_new
      (Nothing, _) -> pure g'

-- | The statement represented by a node, for edge generation only.  A
-- 'StmNode' yields its statement; every other kind of node yields no
-- statements.  Do not use outside of edge generation.
stmFromNode :: NodeT -> Stms SOACS
stmFromNode (StmNode x) = oneStm x
stmFromNode _ = mempty

-- | Get the underlying @fgl@ node, discarding the label.
nodeFromLNode :: DepNode -> G.Node
nodeFromLNode = fst

-- | Get the variable name that this edge refers to (see 'getName').
depsFromEdge :: DepEdge -> VName
depsFromEdge = getName . G.edgeLabel

-- | Find all the edges connecting the two nodes.  This is every edge of
-- the subgraph induced by the two nodes, so it includes edges in both
-- directions as well as any self-loops on either node.
edgesBetween :: DepGraph -> G.Node -> G.Node -> [DepEdge]
edgesBetween dg n1 n2 = G.labEdges $ G.subgraph [n1, n2] $ dgGraph dg

-- | @reachable dg from to@ is true if @to@ is reachable from @from@
-- by following edges forwards, i.e. if @from@ transitively depends on
-- @to@.
reachable :: DepGraph -> G.Node -> G.Node -> Bool
reachable dg source target = target `elem` Q.reachable source (dgGraph dg)

-- | Run an 'EdgeGenerator' over every node of the graph and add the
-- resulting edges (see 'genEdges').
augWithFun :: Monad m => EdgeGenerator -> DepGraphAug m
augWithFun f dg = genEdges (G.labNodes (dgGraph dg)) f dg

-- | Add data-dependency edges.  For each statement node, every variable
-- used as a SOAC array input gets a 'Dep' edge and every variable used
-- in any other way gets an 'InfDep' edge.  A variable used in both ways
-- gets both.  Non-'StmNode' nodes contribute nothing (see
-- 'stmFromNode').
addDeps :: Monad m => DepGraphAug m
addDeps = augWithFun toDep
  where
    toDep :: EdgeGenerator
    toDep stmt =
      let (fusible, infusible) =
            bimap (map fst) (map fst)
              . L.partition ((== SOACInput) . snd)
              . S.toList
              $ foldMap stmInputs (stmFromNode stmt)
          mkDep vname = (vname, Dep vname)
          mkInfDep vname = (vname, InfDep vname)
       in map mkDep fusible <> map mkInfDep infusible

-- | For every statement node, add a 'Cons' edge for each variable the
-- statement consumes and an 'Alias' edge for each variable its pattern
-- aliases.  The alias information comes from running alias analysis on
-- the statement in isolation (with an empty alias table).
addConsAndAliases :: Monad m => DepGraphAug m
addConsAndAliases = augWithFun edges
  where
    edges :: EdgeGenerator
    edges (StmNode s) = consEdges s' <> aliasEdges s'
      where
        s' = Alias.analyseStm mempty s
    edges _ = mempty
    -- One Cons edge per consumed name.
    consEdges s = zip names (map Cons names)
      where
        names = namesToList $ consumedInStm s
    -- One Alias edge per name aliased by any pattern element.
    aliasEdges =
      map (\vname -> (vname, Alias vname))
        . namesToList
        . mconcat
        . patAliases
        . stmPat

-- | Add ordering edges for consumption.  A consuming node must come
-- after every other node that uses the consumed variable or one of its
-- aliases.  For every @Cons cname@ edge @from -> to@, this looks at the
-- users of @to@ and of the producers of @cname@'s aliases (from
-- 'dgAliasTable').  Every such user whose edge is labelled with @cname@
-- or one of those aliases, other than @from@ itself, gets a @Fake
-- cname@ edge from @from@.  Because edges point from dependents to
-- dependencies, this makes the consumer depend on those users.
addExtraCons :: Monad m => DepGraphAug m
addExtraCons dg =
  depGraphInsertEdges (concatMap makeEdge (G.labEdges g)) dg
  where
    g = dgGraph dg
    alias_table = dgAliasTable dg
    mapping = dgProducerMapping dg
    makeEdge :: DepEdge -> [DepEdge]
    makeEdge (from, to, Cons cname) = do
      let aliases = namesToList $ M.findWithDefault mempty cname alias_table
          -- Producers of the aliases that exist in this graph.
          to' = mapMaybe (`M.lookup` mapping) aliases
          -- Keep uses of the consumed name or its aliases by other nodes.
          p (tonode, toedge) =
            tonode /= from && getName toedge `elem` (cname : aliases)
      (to2, _) <- filter p $ concatMap (G.lpre g) to' <> G.lpre g to
      pure $ G.toLEdge (from, to2) (Fake cname)
    makeEdge _ = []

-- | Monadically modify the label of every node, leaving all edges
-- unchanged (built on 'mapAcross').
mapAcrossNodeTs :: Monad m => (NodeT -> m NodeT) -> DepGraphAug m
mapAcrossNodeTs f = mapAcross f'
  where
    f' (ins, n, nodeT, outs) = do
      nodeT' <- f nodeT
      pure (ins, n, nodeT', outs)

-- | Refine a 'StmNode' into a more specific kind of node:
--
-- * an 'Op' that 'H.fromExp' recognises as a SOAC becomes a 'SoacNode'
--   with no array transforms;
-- * a loop becomes a 'DoNode' and a 'Match' a 'MatchNode', both with an
--   empty list;
-- * a single-result expression recognised by 'H.transformFromExp'
--   becomes a 'TransNode'.
--
-- Anything else, including nodes that are not 'StmNode's, is returned
-- unchanged.
nodeToSoacNode :: (HasScope SOACS m, Monad m) => NodeT -> m NodeT
nodeToSoacNode n@(StmNode s@(Let pat aux op)) = case op of
  Op {} -> do
    maybeSoac <- H.fromExp op
    case maybeSoac of
      Right hsoac -> pure $ SoacNode mempty pat hsoac aux
      Left H.NotSOAC -> pure n
  DoLoop {} ->
    pure $ DoNode s []
  Match {} ->
    pure $ MatchNode s []
  e
    | [output] <- patNames pat,
      Just (ia, tr) <- H.transformFromExp (stmAuxCerts aux) e ->
        pure $ TransNode output tr ia
  _ -> pure n
nodeToSoacNode n = pure n

-- | Construct a graph with only nodes, but no edges.  The nodes are, in
-- this order and numbered consecutively from 0:
--
-- * one 'StmNode' per statement of the body;
-- * one 'ResNode' per variable free in the body result;
-- * one 'FreeNode' per variable that alias analysis reports as
--   consumed in the body.
--
-- The alias table from the same analysis is stored in 'dgAliasTable'.
-- The producer mapping is left empty.
emptyGraph :: Body SOACS -> DepGraph
emptyGraph body =
  DepGraph
    { dgGraph = G.mkGraph (labelNodes (stmnodes <> resnodes <> inputnodes)) [],
      dgProducerMapping = mempty,
      dgAliasTable = aliases
    }
  where
    labelNodes = zip [0 ..]
    stmnodes = map StmNode $ stmsToList $ bodyStms body
    resnodes = map ResNode $ namesToList $ freeIn $ bodyResult body
    inputnodes = map FreeNode $ namesToList consumed
    (_, (aliases, consumed)) = Alias.analyseStms mempty $ bodyStms body

-- | Edge generator for result nodes: a 'ResNode' gets a 'Res' edge to
-- the producer of the variable it names.  Other nodes get no edges.
getStmRes :: EdgeGenerator
getStmRes (ResNode name) = [(name, Res name)]
getStmRes _ = []

-- | Connect every result node to the producer of its variable.
addResEdges :: Monad m => DepGraphAug m
addResEdges = augWithFun getStmRes

-- | Make a dependency graph corresponding to a 'Body'.  Starts from
-- 'emptyGraph', builds the producer mapping, adds data, consumption,
-- alias, ordering and result edges, and finally refines statement
-- nodes into their specific kinds.
mkDepGraph :: (HasScope SOACS m, Monad m) => Body SOACS -> m DepGraph
mkDepGraph body = applyAugs augs $ emptyGraph body
  where
    augs =
      [ makeMapping,
        addDeps,
        addConsAndAliases,
        addExtraCons,
        addResEdges,
        -- Must be done after adding edges: the edge generators above
        -- only inspect 'StmNode's, so refined nodes would get no edges.
        mapAcrossNodeTs nodeToSoacNode
      ]

-- | Make a dependency graph corresponding to a function's body.  The
-- graph is built in a scope holding the function's parameters and the
-- names bound by the body's statements.
mkDepGraphForFun :: FunDef SOACS -> DepGraph
mkDepGraphForFun f = runReader (mkDepGraph body) scope
  where
    body = funDefBody f
    scope = scopeOfFParams (funDefParams f) <> scopeOf (bodyStms body)

-- | Merges two contexts.  The result describes the first context's
-- node (@n1@), labelled with the given label.  Its incoming and
-- outgoing adjacencies are the deduplicated unions of both contexts'
-- adjacencies, with every link to either @n1@ or @n2@ removed.  This
-- drops the edges between the two merged nodes and any self-loops.
mergedContext :: Ord b => a -> G.Context a b -> G.Context a b -> G.Context a b
mergedContext mergedlabel (inp1, n1, _, out1) (inp2, n2, _, out2) =
  (merge inp1 inp2, n1, mergedlabel, merge out1 out2)
  where
    -- Union two adjacency lists, dropping duplicates and links to the
    -- merged nodes.
    merge adj1 adj2 =
      filter (\(_, n) -> n /= n1 && n /= n2) (nubOrd (adj1 <> adj2))

-- | @contractEdge n2 ctx@ deletes node @n2@ and the node @n1@ that @ctx@
-- describes, then inserts @ctx@.  The node @n1@ therefore survives with
-- the label and adjacencies given in @ctx@, replacing any existing
-- information about it.  NOTE(review): inserting with 'G.&' presumably
-- requires that @ctx@ does not link to @n1@ or @n2@ (as 'mergedContext'
-- ensures); confirm.
contractEdge :: Monad m => G.Node -> DepContext -> DepGraphAug m
contractEdge n2 ctx dg = do
  let n1 = G.node' ctx -- n1 remains
  pure $ dg {dgGraph = ctx G.& G.delNodes [n1, n2] (dgGraph dg)}

-- Utilities for classifying dependencies as fusible or infusible.  The
-- dependency edges of the graph are generated from these classifications.

-- | A classification of a free variable.
data Classification
  = -- | Used as array input to a SOAC (meaning fusible).
    SOACInput
  | -- | Used in some other way.
    Other
  deriving (Eq, Ord, Show)

-- | A set of variables, each tagged with how it is used.  Since the
-- elements are pairs, one variable may occur with several tags.
type Classifications =
  S.Set (VName, Classification)

-- | Tag every free variable of the argument as 'Other'.
freeClassifications :: FreeIn a => a -> Classifications
freeClassifications x =
  S.fromList [(v, Other) | v <- namesToList (freeIn x)]

-- | Classify the variables used by a statement.  The variables free in
-- its pattern and auxiliary information are tagged 'Other'.  Its
-- expression is classified by 'expInputs'.
stmInputs :: Stm SOACS -> Classifications
stmInputs (Let lhs annot rhs) =
  expInputs rhs <> freeClassifications (lhs, annot)

-- | Classify the variables used by a body.  The result is the inputs
-- of each statement plus the free variables of the body's result,
-- tagged 'Other'.
--
-- NOTE(review): variables bound by the body's own statements are not
-- subtracted, so they can appear in the result too.  Confirm that
-- callers tolerate this.
bodyInputs :: Body SOACS -> Classifications
bodyInputs (Body _ bodystms bodyres) =
  freeClassifications bodyres <> foldMap stmInputs bodystms

-- | Classify the free variables of an expression.
--
-- The following are tagged 'SOACInput':
--
-- * the array inputs of a SOAC;
-- * the array operand of an expression that 'H.transformFromExp'
--   recognises.
--
-- Every other free variable is tagged 'Other'.  The bodies of 'Match'
-- and 'DoLoop' are classified recursively through 'bodyInputs'.
expInputs :: Exp SOACS -> Classifications
expInputs (Match conds branches fallback dec) =
  freeClassifications (conds, dec)
    <> bodyInputs fallback
    <> foldMap (bodyInputs . caseBody) branches
expInputs (DoLoop merge loopform loopbody) =
  bodyInputs loopbody <> freeClassifications (merge, loopform)
expInputs (Op op) =
  case op of
    Futhark.Screma w arrs form -> withArrs arrs (w, form)
    Futhark.Hist w arrs histops lam -> withArrs arrs (w, histops, lam)
    Futhark.Scatter w arrs lam dests -> withArrs arrs (w, lam, dests)
    Futhark.Stream w arrs accs lam -> withArrs arrs (w, accs, lam)
    Futhark.JVP {} -> freeClassifications op
    Futhark.VJP {} -> freeClassifications op
  where
    -- The SOAC's array inputs are fusible; everything else it mentions
    -- is classified as 'Other'.
    withArrs :: FreeIn b => [VName] -> b -> Classifications
    withArrs xs rest =
      S.fromList [(x, SOACInput) | x <- xs] <> freeClassifications rest
expInputs e =
  case H.transformFromExp mempty e of
    -- An array transformation: its array operand is fusible, and the
    -- remaining free variables are not.
    Just (arr, _) ->
      S.insert (arr, SOACInput) $
        freeClassifications (freeIn e `namesSubtract` oneName arr)
    Nothing -> freeClassifications e

-- | The names bound by the pattern of a statement.
stmNames :: Stm SOACS -> [VName]
stmNames stm = patNames (stmPat stm)

-- | The variables produced by a node.
--
-- * Statement-like nodes produce the names in their pattern.
-- * A 'TransNode' produces its first 'VName'.
-- * A 'FreeNode' produces its variable.
-- * A 'ResNode' produces nothing.
getOutputs :: NodeT -> [VName]
getOutputs (StmNode stm) = stmNames stm
getOutputs (TransNode out _ _) = [out]
getOutputs (ResNode _) = []
getOutputs (FreeNode v) = [v]
getOutputs (MatchNode stm _) = stmNames stm
getOutputs (DoNode stm _) = stmNames stm
getOutputs (SoacNode _ pat _ _) = patNames pat

-- | Is there a possibility of fusion?  This holds exactly for 'Dep'
-- and 'Res' edges.
isDep :: EdgeT -> Bool
isDep edge = case edge of
  Dep _ -> True
  Res _ -> True
  _ -> False

-- | Is this an infusible edge?  The endpoints are ignored.
--
-- * 'InfDep' edges are infusible.
-- * 'Fake' edges are also treated as infusible, to avoid simultaneous
--   cons/dep edges.
isInf :: (G.Node, G.Node, EdgeT) -> Bool
isInf (_, _, InfDep _) = True
isInf (_, _, Fake _) = True
isInf _ = False

-- | Is this a 'Cons' edge?
isCons :: EdgeT -> Bool
isCons edge = case edge of
  Cons _ -> True
  _ -> False