-- | Perform horizontal and vertical fusion of SOACs.  See the paper
-- /A T2 Graph-Reduction Approach To Fusion/ for the basic idea (some
-- extensions are discussed in /Design and GPGPU Performance of Futhark’s
-- Redomap Construct/).
module Futhark.Optimise.Fusion (fuseSOACs) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Optimise.Fusion.TryFusion qualified as TF
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- | Mutable state threaded through the fusion pass.
--
-- The fields are strict so that repeated 'modify' calls (such as the
-- increment of 'fusionCount' in 'fusedSomething') do not build up
-- chains of unevaluated thunks.
data FusionEnv = FusionEnv
  { -- | Source of fresh variable names (backs the 'MonadFreshNames'
    -- instance of 'FusionM').
    vNameSource :: !VNameSource,
    -- | Number of fusions performed so far; incremented by
    -- 'fusedSomething'.
    fusionCount :: !Int,
    -- | Whether scanomap SOACs may be fused as producers (consulted by
    -- 'okToFuseProducer').
    fuseScans :: !Bool
  }

-- | The initial fusion state: scan fusion enabled, no fusions counted
-- yet, and a blank name source.  'runFusionM' replaces the name source
-- with the caller's before running.
freshFusionEnv :: FusionEnv
freshFusionEnv =
  FusionEnv
    { fuseScans = True,
      fusionCount = 0,
      vNameSource = blankNameSource
    }

-- | The monad in which fusion runs.  It provides a 'Scope' of the code
-- being fused (a 'ReaderT' environment, locally extensible through
-- 'LocalScope') on top of mutable 'FusionEnv' state.  Every instance is
-- obtained by newtype deriving from the underlying transformer stack.
newtype FusionM a = FusionM (ReaderT (Scope SOACS) (State FusionEnv) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope SOACS,
      LocalScope SOACS,
      MonadState FusionEnv
    )

-- | Fresh names are read from, and written back to, the 'vNameSource'
-- field of the fusion state.
instance MonadFreshNames FusionM where
  getNameSource = vNameSource <$> get
  putNameSource src = do
    env <- get
    put env {vNameSource = src}

-- | Run a fusion computation in the given scope, starting from the
-- given environment.  The environment's own name source is ignored.
-- Instead, the computation borrows the name source of the surrounding
-- 'MonadFreshNames' monad and hands the updated source back afterwards.
-- Apart from its name source, the final 'FusionEnv' is discarded.
runFusionM :: MonadFreshNames m => Scope SOACS -> FusionEnv -> FusionM a -> m a
runFusionM scope fenv (FusionM action) =
  modifyNameSource $ \src ->
    let initial = fenv {vNameSource = src}
        (result, final) = runState (runReaderT action scope) initial
     in (result, vNameSource final)

-- | Run an action with 'fuseScans' temporarily set to the given value.
-- The previous setting is restored once the action finishes.
withFuseScans :: Bool -> FusionM a -> FusionM a
withFuseScans enabled m = do
  saved <- gets fuseScans
  modify $ \s -> s {fuseScans = enabled}
  r <- m
  modify $ \s -> s {fuseScans = saved}
  pure r

-- | Run an action with scan fusion enabled, restoring the previous
-- setting afterwards.
doFuseScans :: FusionM a -> FusionM a
doFuseScans = withFuseScans True

-- | Run an action with scan fusion disabled, restoring the previous
-- setting afterwards.
dontFuseScans :: FusionM a -> FusionM a
dontFuseScans = withFuseScans False

-- | True when neither node can reach the other, as decided by
-- 'reachable' on the dependency graph.
unreachableEitherDir :: DepGraph -> G.Node -> G.Node -> Bool
unreachableEitherDir g a b =
  not (reachable g a b) && not (reachable g b a)

-- | Keep only the inputs for which 'H.isVarInput' returns 'Nothing'.
isNotVarInput :: [H.Input] -> [H.Input]
isNotVarInput inputs = [inp | inp <- inputs, isNothing (H.isVarInput inp)]

-- | Turn a graph node back into the statements it represents.
--
-- * A 'StmNode' yields its statement.
-- * A 'SoacNode' binds the SOAC to fresh names first.  Each original
--   output name is then bound to the result of applying the node's
--   output 'ArrayTransforms' to the corresponding fresh name.
-- * 'DoNode' and 'MatchNode' emit the statements of the nodes attached
--   to them, followed by the node's own statement.
-- * 'FinalNode' places the finalised inner node between its leading
--   and trailing statements.
-- * 'ResNode' and 'FreeNode' produce no statements.
finalizeNode :: (HasScope SOACS m, MonadFreshNames m) => NodeT -> m (Stms SOACS)
finalizeNode nt = case nt of
  StmNode stm -> pure $ oneStm stm
  SoacNode ots outputs soac aux -> runBuilder_ $ do
    let out_names = patNames outputs
    raw_names <- mapM newName out_names
    -- 'H.toSOAC' stays inside 'auxing' so that any statements it may
    -- add carry the same auxiliary information as the SOAC itself.
    auxing aux $ do
      soac' <- H.toSOAC soac
      letBindNames raw_names $ Op soac'
    forM_ (zip out_names raw_names) $ \(out_name, raw_name) -> do
      transformed <- H.applyTransforms ots raw_name
      letBindNames [out_name] $ BasicOp $ SubExp $ Var transformed
  ResNode _ -> pure mempty
  FreeNode _ -> pure mempty
  DoNode stm attached -> attachedThen stm attached
  MatchNode stm attached -> attachedThen stm attached
  FinalNode before inner after -> do
    inner_stms <- finalizeNode inner
    pure $ before <> inner_stms <> after
  where
    -- Statements of the attached nodes, in order, then the node's own.
    attachedThen stm attached = do
      attached_stms <- mapM (finalizeNode . fst) attached
      pure $ mconcat attached_stms <> oneStm stm

-- | Turn the whole dependency graph back into a sequence of statements
-- by finalising every node in reverse topological order.
-- NOTE(review): this presumes dependency edges point from a consumer to
-- its producer, so reversing puts producers first — confirm against
-- "Futhark.Optimise.Fusion.GraphRep".
linearizeGraph :: (HasScope SOACS m, MonadFreshNames m) => DepGraph -> m (Stms SOACS)
linearizeGraph dg = do
  let order = reverse $ Q.topsort' $ dgGraph dg
  stms <- mapM finalizeNode order
  pure $ mconcat stms

-- | Record that a fusion happened by incrementing 'fusionCount', and
-- return the resulting node wrapped in 'Just'.
fusedSomething :: NodeT -> FusionM (Maybe NodeT)
fusedSomething x = Just x <$ modify bump
  where
    bump env = env {fusionCount = fusionCount env + 1}

-- | For each node, find what came before it and attempt to fuse those
-- nodes with each other horizontally.  This means we only perform
-- horizontal fusion for SOACs that use the same input in some way.
-- Only predecessors reached through an 'isDep' edge are considered.
-- Each pair @(x, y)@ with @x < y@ is tried in turn via
-- 'hTryFuseNodesInGraph'.
horizontalFusionOnNode :: G.Node -> DepGraphAug FusionM
horizontalFusionOnNode node dg =
  applyAugs [hTryFuseNodesInGraph x y | x <- incoming, y <- incoming, x < y] dg
  where
    incoming = [n | (n, e) <- G.lpre (dgGraph dg) node, isDep e]

-- | Can producer @n1@ be vertically fused with consumer @n2@?  Two
-- conditions must hold:
--
-- * no edge between the two nodes is infusible ('isInf');
-- * @n2@ cannot reach any other predecessor of @n1@.  Otherwise,
--   contracting the edge would presumably create a cycle.
vFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
vFusionFeasability dg n1 n2 = no_inf_edges && no_detour
  where
    no_inf_edges = all (not . isInf) $ edgesBetween dg n1 n2
    other_preds = [p | p <- G.pre (dgGraph dg) n1, p /= n2]
    no_detour = all (not . reachable dg n2) other_preds

-- | Two nodes may be fused horizontally only when neither can reach
-- the other in the dependency graph.
hFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
hFusionFeasability dg x y = unreachableEitherDir dg x y

-- | Attempt to vertically fuse the given node with each of its
-- predecessors that are connected through an 'isDep' edge.  If the
-- node is no longer in the graph (presumably because an earlier
-- contraction removed it), the graph is returned unchanged.
tryFuseNodeInGraph :: DepNode -> DepGraphAug FusionM
tryFuseNodeInGraph node_to_fuse dg
  | node_id `G.gelem` g =
      applyAugs [vTryFuseNodesInGraph node_id other | other <- partners] dg
  | otherwise = pure dg
  where
    g = dgGraph dg
    node_id = nodeFromLNode node_to_fuse
    partners = [n | (n, e) <- G.lpre g node_id, isDep e]

-- | Try to vertically fuse producer @node_1@ with consumer @node_2@.
-- The steps are: check that both nodes are still present, verify that
-- fusion causes no cycles ('vFusionFeasability'), fuse the two
-- contexts, and finally contract @node_2@ into the merged node at
-- @node_1@.
--
-- When the edges between the two nodes include consumption edges
-- ('isCons'), the fused SOAC gets copies of the variables its lambda
-- consumes.  Names already consumed according to the outgoing edges of
-- either original context are excluded from copying.
vTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
vTryFuseNodesInGraph node_1 node_2 dg
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | not (vFusionFeasability dg node_1 node_2) = pure dg
  | otherwise = do
      let ctx1 = G.context g node_1
          ctx2 = G.context g node_2
      fres <- vFuseContexts edgs infusable_nodes ctx1 ctx2
      case fres of
        Nothing -> pure dg
        Just (inputs, _, nodeT, outputs) -> do
          nodeT' <- copyIfConsuming ctx1 ctx2 nodeT
          contractEdge node_2 (inputs, node_1, nodeT', outputs) dg
  where
    g = dgGraph dg
    edgs = map G.edgeLabel $ edgesBetween dg node_1 node_2
    fusedC = [getName e | e <- edgs, isCons e]
    infusable_nodes =
      map depsFromEdge $
        concatMap (edgesBetween dg node_1) [p | p <- G.pre g node_1, p /= node_2]
    -- Make copies of everything that was not previously consumed.
    copyIfConsuming ctx1 ctx2 nodeT
      | null fusedC = pure nodeT
      | otherwise = makeCopiesOfFusedExcept (consumedBy ctx1 <> consumedBy ctx2) nodeT
    consumedBy (_, _, _, deps) = [getName e | (e, _) <- deps, isCons e]

-- | Try to horizontally fuse two nodes.  Both must still be in the
-- graph, and they must be independent of each other
-- ('hFusionFeasability').  On success, @node_2@ is contracted into the
-- merged context.
hTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
hTryFuseNodesInGraph node_1 node_2 dg
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | hFusionFeasability dg node_1 node_2 =
      maybe (pure dg) (\ctx -> contractEdge node_2 ctx dg)
        =<< hFuseContexts (G.context g node_1) (G.context g node_2)
  | otherwise = pure dg
  where
    g = dgGraph dg

-- | Horizontally fuse the node labels of two contexts using
-- 'hFuseNodeT'.  On success, the contexts are combined around the fused
-- label using 'mergedContext'.
hFuseContexts :: DepContext -> DepContext -> FusionM (Maybe DepContext)
hFuseContexts c1@(_, _, nodeT1, _) c2@(_, _, nodeT2, _) =
  fmap merge <$> hFuseNodeT nodeT1 nodeT2
  where
    merge nodeT = mergedContext nodeT c1 c2

-- | Vertically fuse the node labels of producer context @c1@ and
-- consumer context @c2@ using 'vFuseNodeT'.
--
-- * The producer is passed the labels of its incoming edges, except
--   those from @c2@'s node, together with all its outgoing edge labels.
-- * The consumer is passed its outgoing edge labels, except those to
--   @c1@'s node.
--
-- On success, the contexts are combined using 'mergedContext'.
vFuseContexts :: [EdgeT] -> [VName] -> DepContext -> DepContext -> FusionM (Maybe DepContext)
vFuseContexts edgs infusable c1@(i1, n1, nodeT1, o1) c2@(_, n2, nodeT2, o2) =
  fmap merge
    <$> vFuseNodeT
      edgs
      infusable
      (nodeT1, [e | (e, n) <- i1, n /= n2], map fst o1)
      (nodeT2, [e | (e, n) <- o2, n /= n1])
  where
    merge nodeT = mergedContext nodeT c1 c2

-- | For a 'SoacNode', add copies (via 'makeCopiesInLambda') of every
-- variable consumed by the SOAC's lambda.  Variables of accumulator
-- type and names in the given exclusion list are not copied.  All
-- other kinds of node are returned unchanged.
makeCopiesOfFusedExcept ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  NodeT ->
  m NodeT
makeCopiesOfFusedExcept noCopy (SoacNode ots pats soac aux) = do
  let lam = H.lambda soac
      consumed = namesToList $ consumedByLambda $ Alias.analyseLambda mempty lam
  -- Type lookups need the lambda's parameters in scope.
  localScope (scopeOf lam) $ do
    let notAcc v = not . isAcc <$> lookupType v
    fused_inner <- filterM notAcc consumed
    lam' <- makeCopiesInLambda (fused_inner L.\\ noCopy) lam
    pure $ SoacNode ots pats (H.setLambda lam' soac) aux
makeCopiesOfFusedExcept _ nodeT = pure nodeT

-- | Add a fresh copy of each given variable to the lambda body (using
-- 'insertStms'), and rename the body's uses of the originals to the
-- copies.
makeCopiesInLambda ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Lambda SOACS ->
  m (Lambda SOACS)
makeCopiesInLambda toCopy lam = do
  (copy_stms, renaming) <- makeCopyStms toCopy
  let body' = insertStms copy_stms $ substituteNames renaming $ lambdaBody lam
  pure lam {lambdaBody = body'}

-- | For each variable @v@, create a fresh name whose base string is
-- @v@ followed by @\"_copy\"@, bound by a 'Copy' of @v@.  Returns the
-- copy statements together with the map from original names to fresh
-- names.
makeCopyStms ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  m (Stms SOACS, M.Map VName VName)
makeCopyStms vs = do
  vs' <- forM vs $ \v -> newVName $ baseString v <> "_copy"
  let renaming = zip vs vs'
  copies <- mapM (\(v, v') -> mkLetNames [v'] $ BasicOp $ Copy v) renaming
  pure (stmsFromList copies, M.fromList renaming)

-- | May this SOAC act as the producer in a vertical fusion?  A scanomap
-- 'H.Screma' is allowed only while 'fuseScans' is set.  Every other
-- SOAC, including non-scan 'H.Screma's, is always allowed.
okToFuseProducer :: H.SOAC SOACS -> FusionM Bool
okToFuseProducer (H.Screma _ form _)
  | isJust (Futhark.isScanomapSOAC form) = gets fuseScans
okToFuseProducer _ = pure True

-- First node is producer, second is consumer.
vFuseNodeT :: [EdgeT] -> [VName] -> (NodeT, [EdgeT], [EdgeT]) -> (NodeT, [EdgeT]) -> FusionM (Maybe NodeT)
vFuseNodeT :: [EdgeT]
-> [VName]
-> (NodeT, [EdgeT], [EdgeT])
-> (NodeT, [EdgeT])
-> FusionM (Maybe NodeT)
vFuseNodeT [EdgeT]
_ [VName]
infusible (NodeT
s1, [EdgeT]
_, [EdgeT]
e1s) (MatchNode Stm SOACS
stm2 [(NodeT, [EdgeT])]
dfused, [EdgeT]
_)
  | NodeT -> Bool
isRealNode NodeT
s1,
    forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
infusible =
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ Stm SOACS -> [(NodeT, [EdgeT])] -> NodeT
MatchNode Stm SOACS
stm2 forall a b. (a -> b) -> a -> b
$ (NodeT
s1, [EdgeT]
e1s) forall a. a -> [a] -> [a]
: [(NodeT, [EdgeT])]
dfused
vFuseNodeT [EdgeT]
_ [VName]
infusible (StmNode Stm SOACS
stm1, [EdgeT]
_, [EdgeT]
_) (SoacNode ArrayTransforms
ots2 Pat Type
pats2 SOAC SOACS
soac2 StmAux (ExpDec SOACS)
aux2, [EdgeT]
_)
  | forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
infusible,
    [VName
stm1_out] <- forall dec. Pat dec -> [VName]
patNames forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm SOACS
stm1,
    Just (VName
stm1_in, ArrayTransform
tr) <-
      forall {k} (rep :: k).
Certs -> Exp rep -> Maybe (VName, ArrayTransform)
H.transformFromExp (forall dec. StmAux dec -> Certs
stmAuxCerts (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm1)) (forall {k} (rep :: k). Stm rep -> Exp rep
stmExp Stm SOACS
stm1) = do
      Type
stm1_in_t <- forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType VName
stm1_in
      let onInput :: Input -> Input
onInput Input
inp
            | Input -> VName
H.inputArray Input
inp forall a. Eq a => a -> a -> Bool
== VName
stm1_out =
                ArrayTransforms -> VName -> Type -> Input
H.Input (ArrayTransform
tr ArrayTransform -> ArrayTransforms -> ArrayTransforms
H.<| Input -> ArrayTransforms
H.inputTransforms Input
inp) VName
stm1_in Type
stm1_in_t
            | Bool
otherwise =
                Input
inp
          soac2' :: SOAC SOACS
soac2' = forall a b. (a -> b) -> [a] -> [b]
map Input -> Input
onInput (forall {k} (rep :: k). SOAC rep -> [Input]
H.inputs SOAC SOACS
soac2) forall {k} (rep :: k). [Input] -> SOAC rep -> SOAC rep
`H.setInputs` SOAC SOACS
soac2
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ ArrayTransforms
-> Pat Type -> SOAC SOACS -> StmAux (ExpDec SOACS) -> NodeT
SoacNode ArrayTransforms
ots2 Pat Type
pats2 SOAC SOACS
soac2' StmAux (ExpDec SOACS)
aux2
vFuseNodeT
  [EdgeT]
_
  [VName]
_
  (SoacNode ArrayTransforms
ots1 Pat Type
pats1 SOAC SOACS
soac1 StmAux (ExpDec SOACS)
aux1, [EdgeT]
i1s, [EdgeT]
_e1s)
  (SoacNode ArrayTransforms
ots2 Pat Type
pats2 SOAC SOACS
soac2 StmAux (ExpDec SOACS)
aux2, [EdgeT]
_e2s) = do
    let ker :: FusedSOAC
ker =
          TF.FusedSOAC
            { fsSOAC :: SOAC SOACS
TF.fsSOAC = SOAC SOACS
soac2,
              fsOutputTransform :: ArrayTransforms
TF.fsOutputTransform = ArrayTransforms
ots2,
              fsOutNames :: [VName]
TF.fsOutNames = forall dec. Pat dec -> [VName]
patNames Pat Type
pats2
            }
        preserveEdge :: EdgeT -> Bool
preserveEdge InfDep {} = Bool
True
        preserveEdge EdgeT
e = EdgeT -> Bool
isDep EdgeT
e
        preserve :: Names
preserve = [VName] -> Names
namesFromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map EdgeT -> VName
getName forall a b. (a -> b) -> a -> b
$ forall a. (a -> Bool) -> [a] -> [a]
filter EdgeT -> Bool
preserveEdge [EdgeT]
i1s
    Bool
ok <- SOAC SOACS -> FusionM Bool
okToFuseProducer SOAC SOACS
soac1
    Maybe FusedSOAC
r <-
      if Bool
ok Bool -> Bool -> Bool
&& ArrayTransforms
ots1 forall a. Eq a => a -> a -> Bool
== forall a. Monoid a => a
mempty
        then forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Names -> [VName] -> SOAC SOACS -> FusedSOAC -> m (Maybe FusedSOAC)
TF.attemptFusion Names
preserve (forall dec. Pat dec -> [VName]
patNames Pat Type
pats1) SOAC SOACS
soac1 FusedSOAC
ker
        else forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a. Maybe a
Nothing
    case Maybe FusedSOAC
r of
      Just FusedSOAC
ker' -> do
        let pats2' :: [PatElem Type]
pats2' =
              forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall dec. VName -> dec -> PatElem dec
PatElem (FusedSOAC -> [VName]
TF.fsOutNames FusedSOAC
ker') (forall {k} (rep :: k). SOAC rep -> [Type]
H.typeOf (FusedSOAC -> SOAC SOACS
TF.fsSOAC FusedSOAC
ker'))
        NodeT -> FusionM (Maybe NodeT)
fusedSomething forall a b. (a -> b) -> a -> b
$
          ArrayTransforms
-> Pat Type -> SOAC SOACS -> StmAux (ExpDec SOACS) -> NodeT
SoacNode
            (FusedSOAC -> ArrayTransforms
TF.fsOutputTransform FusedSOAC
ker')
            (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
pats2')
            (FusedSOAC -> SOAC SOACS
TF.fsSOAC FusedSOAC
ker')
            (StmAux (ExpDec SOACS)
aux1 forall a. Semigroup a => a -> a -> a
<> StmAux (ExpDec SOACS)
aux2)
      Maybe FusedSOAC
Nothing -> forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a. Maybe a
Nothing
vFuseNodeT [EdgeT]
_ [VName]
_ (NodeT, [EdgeT], [EdgeT])
_ (NodeT, [EdgeT])
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a. Maybe a
Nothing

-- | The result subexpressions of a lambda's body.
resFromLambda :: Lambda rep -> Result
resFromLambda = bodyResult . lambdaBody

-- | Horizontal-fusion precondition on the inputs of two SOACs.
--
-- Removes from @is2@ the inputs that also occur in @is1@ (list
-- difference, so multiplicity matters). It then requires that no
-- non-plain-variable input of @is1@ occurs in that remainder. Equality on
-- 'H.Input' compares the whole input, including its transforms.
--
-- NOTE(review): because @vs2@ is derived from @is2 \\\\ is1@, the
-- intersection can only be non-empty when an input occurs more often in
-- @is2@ than in @is1@. Confirm whether the intent was to compare by
-- underlying array instead.
hasNoDifferingInputs :: [H.Input] -> [H.Input] -> Bool
hasNoDifferingInputs is1 is2 =
  let (vs1, vs2) = (isNotVarInput is1, isNotVarInput $ is2 L.\\ is1)
   in null $ vs1 `L.intersect` vs2

-- | Attempt horizontal fusion of two SOAC nodes.
--
-- Fusion is only attempted when neither node carries output array
-- transforms and 'hasNoDifferingInputs' accepts their inputs. @soac1@ is
-- fused into @soac2@ with every output of @soac1@ preserved. The fused
-- node gets a pattern rebuilt from the fused output names and types, and
-- the two statement auxiliaries are combined. Returns 'Nothing' when the
-- nodes are not both SOACs, the preconditions fail, or
-- 'TF.attemptFusion' declines.
hFuseNodeT :: NodeT -> NodeT -> FusionM (Maybe NodeT)
hFuseNodeT (SoacNode ots1 pats1 soac1 aux1) (SoacNode ots2 pats2 soac2 aux2)
  | ots1 == mempty,
    ots2 == mempty,
    hasNoDifferingInputs (H.inputs soac1) (H.inputs soac2) = do
      let ker =
            TF.FusedSOAC
              { TF.fsSOAC = soac2,
                TF.fsOutputTransform = mempty,
                TF.fsOutNames = patNames pats2
              }
          -- All of soac1's results must remain available after fusion.
          preserve = namesFromList $ patNames pats1
      r <- TF.attemptFusion preserve (patNames pats1) soac1 ker
      case r of
        Just ker' -> do
          let pats2' =
                zipWith PatElem (TF.fsOutNames ker') (H.typeOf (TF.fsSOAC ker'))
          fusedSomething $
            SoacNode mempty (Pat pats2') (TF.fsSOAC ker') (aux1 <> aux2)
        Nothing -> pure Nothing
hFuseNodeT _ _ = pure Nothing

-- | Drop unneeded map outputs from a Screma node.
--
-- All scan and reduction outputs are kept: they are the first
-- @scanResults + redResults@ elements of both the pattern and the lambda
-- result. Of the remaining (map) outputs, only those whose pattern name is
-- in @toKeep@ survive. The lambda's return types and body results are
-- trimmed to match. Any node that is not a Screma is returned unchanged.
removeOutputsExcept :: [VName] -> NodeT -> NodeT
removeOutputsExcept toKeep s = case s of
  SoacNode ots (Pat pats1) soac@(H.Screma _ (ScremaForm scans_1 red_1 lam_1) _) aux1 ->
    SoacNode ots (Pat $ pats_unchanged <> pats_new) (H.setLambda lam_new soac) aux1
    where
      -- Set membership instead of a list 'elem' per output.
      keep = namesFromList toKeep
      num_fixed = Futhark.scanResults scans_1 + Futhark.redResults red_1

      (pats_unchanged, pats_toChange) = splitAt num_fixed pats1
      (res_unchanged, res_toChange) =
        splitAt num_fixed $ zip (resFromLambda lam_1) (lambdaReturnType lam_1)

      (pats_new, other) =
        unzip $
          filter
            (\(x, _) -> patElemName x `nameIn` keep)
            (zip pats_toChange res_toChange)
      (results, types) = unzip (res_unchanged ++ other)
      lam_new =
        lam_1
          { lambdaReturnType = types,
            lambdaBody = (lambdaBody lam_1) {bodyResult = results}
          }
  node -> node

-- | The variable name carried by an adjacency entry of node @n1@.
-- Builds the edge tuple @(n2, n1, edge)@ and asks 'depsFromEdge' for it.
vNameFromAdj :: G.Node -> (EdgeT, G.Node) -> VName
vNameFromAdj n1 (edge, n2) = depsFromEdge (n2, n1, edge)

-- | Trim a node's outputs to the names referenced by its incoming edges.
--
-- The names to keep are read off the context's incoming adjacency via
-- 'vNameFromAdj'. This presumes the incoming edges come from the
-- consumers of this node's results; confirm against how 'mkDepGraph'
-- orients edges. Edges and node id are left unchanged.
removeUnusedOutputsFromContext :: DepContext -> FusionM DepContext
removeUnusedOutputsFromContext (incoming, n1, nodeT, outgoing) =
  pure (incoming, n1, nodeT', outgoing)
  where
    toKeep = map (vNameFromAdj n1) incoming
    nodeT' = removeOutputsExcept toKeep nodeT

-- | Apply 'removeUnusedOutputsFromContext' to every node context.
removeUnusedOutputs :: DepGraphAug FusionM
removeUnusedOutputs = mapAcross removeUnusedOutputsFromContext

-- | Try vertical fusion at each labelled node of the input graph.
-- The nodes are visited in reverse 'G.labNodes' order.
doVerticalFusion :: DepGraphAug FusionM
doVerticalFusion dg =
  applyAugs (map tryFuseNodeInGraph $ reverse $ G.labNodes (dgGraph dg)) dg

-- | Try horizontal fusion at each node id of the input graph.
-- The set of node ids is taken before any fusion happens.
doHorizontalFusion :: DepGraphAug FusionM
doHorizontalFusion dg =
  applyAugs (map horizontalFusionOnNode (G.nodes (dgGraph dg))) dg

-- | Recursively fuse the bodies nested inside every node
-- (see 'runInnerFusionOnContext').
doInnerFusion :: DepGraphAug FusionM
doInnerFusion = mapAcross runInnerFusionOnContext

-- | Fixed-point iteration: re-apply @f@ for as long as an application
-- changes the 'fusionCount' in the state, and return the final graph.
keepTrying :: DepGraphAug FusionM -> DepGraphAug FusionM
keepTrying f g = do
  prev_fused <- gets fusionCount
  g' <- f g
  aft_fused <- gets fusionCount
  if prev_fused /= aft_fused
    then keepTrying f g'
    else pure g'

-- | The full fusion pipeline on a dependency graph.
--
-- Vertical, horizontal and inner fusion are repeated until a fixed point
-- is reached (see 'keepTrying'). Unused SOAC outputs are removed
-- afterwards.
doAllFusion :: DepGraphAug FusionM
doAllFusion =
  applyAugs
    [ keepTrying . applyAugs $
        [ doVerticalFusion,
          doHorizontalFusion,
          doInnerFusion
        ],
      removeUnusedOutputs
    ]

-- | Run fusion inside the bodies nested in a node.
--
-- * Loops and matches: the delayed nodes (@to_fuse@) are finalised and
--   prepended to each nested body before fusing it. The node's delayed
--   list is then cleared. Match case bodies are additionally renamed,
--   presumably so the copies of the delayed statements placed in each
--   branch do not share binding names.
-- * VJP, JVP and WithAcc: the lambda is fused.
-- * SOACs: the lambda is fused under the lambda's scope. Scan fusion is
--   disabled inside 'H.Stream' lambdas and enabled everywhere else.
--
-- All other nodes are returned unchanged.
runInnerFusionOnContext :: DepContext -> FusionM DepContext
runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) = case nodeT of
  DoNode (Let pat aux (DoLoop params form body)) to_fuse ->
    doFuseScans . localScope (scopeOfFParams (map fst params) <> scopeOf form) $ do
      b <- doFusionWithDelayed body to_fuse
      pure (incoming, node, DoNode (Let pat aux (DoLoop params form b)) [], outgoing)
  MatchNode (Let pat aux (Match cond cases defbody dec)) to_fuse -> doFuseScans $ do
    cases' <- mapM (traverse $ renameBody <=< (`doFusionWithDelayed` to_fuse)) cases
    defbody' <- doFusionWithDelayed defbody to_fuse
    pure (incoming, node, MatchNode (Let pat aux (Match cond cases' defbody' dec)) [], outgoing)
  StmNode (Let pat aux (Op (Futhark.VJP lam args vec))) -> doFuseScans $ do
    lam' <- doFusionLambda lam
    pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP lam' args vec))), outgoing)
  StmNode (Let pat aux (Op (Futhark.JVP lam args vec))) -> doFuseScans $ do
    lam' <- doFusionLambda lam
    pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP lam' args vec))), outgoing)
  StmNode (Let pat aux (WithAcc inputs lam)) -> doFuseScans $ do
    lam' <- doFusionLambda lam
    pure (incoming, node, StmNode (Let pat aux (WithAcc inputs lam')), outgoing)
  SoacNode ots pat soac aux -> do
    let lam = H.lambda soac
    lam' <- localScope (scopeOf lam) $ case soac of
      H.Stream {} ->
        dontFuseScans $ doFusionLambda lam
      _ ->
        doFuseScans $ doFusionLambda lam
    let nodeT' = SoacNode ots pat (H.setLambda lam' soac) aux
    pure (incoming, node, nodeT', outgoing)
  _ -> pure c
  where
    -- Finalise the delayed nodes, prepend their statements to the body,
    -- and fuse the result. The body's result is kept as is.
    doFusionWithDelayed :: Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
    doFusionWithDelayed (Body () stms res) extraNodes = localScope (scopeOf stms) $ do
      stm_node <- mapM (finalizeNode . fst) extraNodes
      stms' <- fuseGraph (mkBody (mconcat stm_node <> stms) res)
      pure $ Body () stms' res

    -- Fuse the statements of a body, keeping its other fields.
    doFusionBody :: Body SOACS -> FusionM (Body SOACS)
    doFusionBody body = do
      stms' <- fuseGraph body
      pure $ body {bodyStms = stms'}

    -- Fuse a lambda's body. The lambda is simplified beforehand, and again
    -- afterwards only if fusion actually happened inside it.
    doFusionLambda :: Lambda SOACS -> FusionM (Lambda SOACS)
    doFusionLambda lam = do
      -- To clean up previous instances of fusion.
      lam' <- simplifyLambda lam
      prev_count <- gets fusionCount
      newbody <- localScope (scopeOf lam') $ doFusionBody $ lambdaBody lam'
      aft_count <- gets fusionCount
      -- To clean up any inner fusion.
      (if prev_count /= aft_count then simplifyLambda else pure)
        lam' {lambdaBody = newbody}

-- | Main fusion function for a body.
--
-- Builds the dependency graph under the scope of the body's statements,
-- runs 'doAllFusion' on it, and linearises the fused graph back into
-- statements.
fuseGraph :: Body SOACS -> FusionM (Stms SOACS)
fuseGraph body = localScope (scopeOf (bodyStms body)) $ do
  graph_not_fused <- mkDepGraph body
  graph_fused <- doAllFusion graph_not_fused
  linearizeGraph graph_fused

-- | Fuse top-level constant statements.
--
-- @outputs@ is used as the result of a synthetic body, so those names
-- count as used after fusion.
fuseConsts :: [VName] -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts outputs stms =
  runFusionM
    (scopeOf stms)
    freshFusionEnv
    (fuseGraph (mkBody stms (varsRes outputs)))

-- | Fuse the body of a function.
--
-- The function's own scope and the (already fused) constants are in
-- scope during fusion. Only the body's statements are replaced.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  fun_stms' <-
    runFusionM
      (scopeOf fun <> scopeOf consts)
      freshFusionEnv
      (fuseGraph (funDefBody fun))
  pure fun {funDefBody = (funDefBody fun) {bodyStms = fun_stms'}}

-- | The pass definition.
--
-- Constants are fused with every name free in any function treated as a
-- required output. Each function is then fused separately.
{-# NOINLINE fuseSOACs #-}
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = \p ->
        intraproceduralTransformationWithConsts
          (fuseConsts (namesToList $ freeIn (progFuns p)))
          fuseFun
          p
    }