{-# LANGUAGE Strict #-}

-- | Perform horizontal and vertical fusion of SOACs.  See the paper
-- /A T2 Graph-Reduction Approach To Fusion/ for the basic idea (some
-- extensions discussed in /Design and GPGPU Performance of Futhark’s
-- Redomap Construct/).
module Futhark.Optimise.Fusion (fuseSOACs) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Optimise.Fusion.TryFusion qualified as TF
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- | The state threaded through the fusion pass by 'FusionM'.
data FusionEnv = FusionEnv
  { -- | Supply of fresh variable names (backs the 'MonadFreshNames'
    -- instance of 'FusionM').
    vNameSource :: VNameSource,
    -- | How many fusions have been performed; bumped by 'fusedSomething'.
    fusionCount :: Int,
    -- | When 'False', scanomap producers are rejected by
    -- 'okToFuseProducer'.
    fuseScans :: Bool
  }

-- | Initial environment: a blank name source, no fusions recorded yet,
-- and scan fusion enabled.
freshFusionEnv :: FusionEnv
freshFusionEnv =
  FusionEnv
    { fuseScans = True,
      fusionCount = 0,
      vNameSource = blankNameSource
    }

-- | The fusion monad: read-only access to the 'SOACS' scope layered over
-- mutable 'FusionEnv' state.
newtype FusionM a = FusionM (ReaderT (Scope SOACS) (State FusionEnv) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope SOACS,
      LocalScope SOACS,
      MonadState FusionEnv
    )

-- | Fresh names are read from and written back to the 'vNameSource'
-- field of the fusion state.
instance MonadFreshNames FusionM where
  getNameSource = vNameSource <$> get
  putNameSource src = modify $ \env -> env {vNameSource = src}

-- | Run a 'FusionM' computation in the given scope, starting from the
-- given environment.  The environment's name source is replaced by the
-- ambient one of @m@, and the final name source is handed back to @m@.
-- All other fields of the final environment are discarded.
runFusionM :: MonadFreshNames m => Scope SOACS -> FusionEnv -> FusionM a -> m a
runFusionM scope fenv (FusionM action) = modifyNameSource step
  where
    step src =
      let start = fenv {vNameSource = src}
          (result, final) = runState (runReaderT action scope) start
       in (result, vNameSource final)

-- | Run an action with 'fuseScans' temporarily set to the given value.
-- The previous setting is restored once the action completes.
withFuseScans :: Bool -> FusionM a -> FusionM a
withFuseScans enabled m = do
  saved <- gets fuseScans
  modify $ \s -> s {fuseScans = enabled}
  r <- m
  modify $ \s -> s {fuseScans = saved}
  pure r

-- | Run an action with fusion of scan producers enabled.  The previous
-- setting is restored afterwards.
doFuseScans :: FusionM a -> FusionM a
doFuseScans = withFuseScans True

-- | Run an action with fusion of scan producers disabled.  The previous
-- setting is restored afterwards.
dontFuseScans :: FusionM a -> FusionM a
dontFuseScans = withFuseScans False

-- | 'True' unless 'reachable' holds between the two nodes in either
-- direction.
unreachableEitherDir :: DepGraph -> G.Node -> G.Node -> Bool
unreachableEitherDir g a b =
  not (reachable g a b) && not (reachable g b a)

-- | Keep only the inputs for which 'H.isVarInput' yields 'Nothing'.
isNotVarInput :: [H.Input] -> [H.Input]
isNotVarInput inps = [inp | inp <- inps, isNothing (H.isVarInput inp)]

-- | Turn a single graph node back into statements.
--
-- * 'StmNode': the statement itself.
-- * 'SoacNode': the SOAC is bound (under the node's 'StmAux') to fresh
--   names.  Each pattern name is then bound to the corresponding fresh
--   name with the node's output transforms applied.
-- * 'TransNode': the output is bound to the transform's expression, under
--   the certificates it requires.
-- * 'ResNode' and 'FreeNode': no statements.
-- * 'DoNode' and 'MatchNode': the finalized nested nodes, followed by the
--   node's own statement.
finalizeNode :: (HasScope SOACS m, MonadFreshNames m) => NodeT -> m (Stms SOACS)
finalizeNode nt = case nt of
  StmNode stm -> pure $ oneStm stm
  SoacNode ots outputs soac aux -> runBuilder_ $ do
    let out_names = patNames outputs
    raw_outputs <- mapM newName out_names
    auxing aux $ do
      soac_exp <- H.toSOAC soac
      letBindNames raw_outputs $ Op soac_exp
    forM_ (zip out_names raw_outputs) $ \(output, raw) -> do
      transformed <- H.applyTransforms ots raw
      letBindNames [output] $ BasicOp $ SubExp $ Var transformed
  ResNode _ -> pure mempty
  TransNode output tr ia -> do
    (cs, e) <- H.transformToExp tr ia
    runBuilder_ $ certifying cs $ letBindNames [output] e
  FreeNode _ -> pure mempty
  DoNode stm nested -> withNested stm nested
  MatchNode stm nested -> withNested stm nested
  where
    -- Shared by 'DoNode' and 'MatchNode': nested statements first, then
    -- the enclosing statement.
    withNested stm nested = do
      nested_stms <- mapM (finalizeNode . fst) nested
      pure $ mconcat nested_stms <> oneStm stm

-- | Flatten the dependency graph back into a sequence of statements.
-- Every node is finalized in the reverse of the order returned by
-- @Q.topsort'@, and the results are concatenated.
linearizeGraph :: (HasScope SOACS m, MonadFreshNames m) => DepGraph -> m (Stms SOACS)
linearizeGraph dg = do
  let ordered = reverse $ Q.topsort' $ dgGraph dg
  stms <- mapM finalizeNode ordered
  pure $ mconcat stms

-- | Record that a fusion happened by incrementing 'fusionCount', and
-- return the fused node wrapped in 'Just'.
fusedSomething :: NodeT -> FusionM (Maybe NodeT)
fusedSomething x =
  Just x <$ modify (\s -> s {fusionCount = fusionCount s + 1})

-- | Whether @n1@ and @n2@ may be fused vertically.  Two conditions must
-- hold:
--
-- * no edge between them satisfies 'isInf';
-- * @'reachable' dg n2 p@ is 'False' for every predecessor @p@ of @n1@
--   other than @n2@ itself.
vFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
vFusionFeasability dg@DepGraph {dgGraph = g} n1 n2 =
  all (not . isInf) (edgesBetween dg n1 n2)
    && all (not . reachable dg n2) [p | p <- G.pre g n1, p /= n2]

-- | Horizontal fusion is feasible exactly when 'unreachableEitherDir'
-- holds for the two nodes.
hFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
hFusionFeasability dg n1 n2 = unreachableEitherDir dg n1 n2

-- | Try to vertically fuse @node_1@ with @node_2@.  @node_1@'s context is
-- passed first to 'vFuseContexts', i.e. it plays the producer role.
--
-- Find the neighbours, then verify that fusion causes no cycles
-- ('vFusionFeasability'), then fuse.  On success, a context whose node
-- is @node_1@ is handed to 'contractEdge' together with @node_2@.
--
-- The graph is returned unchanged in three cases: either node is
-- missing, fusion is infeasible, or 'vFuseContexts' gives 'Nothing'.
vTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
vTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | vFusionFeasability dg node_1 node_2 = do
      let (ctx1, ctx2) = (G.context g node_1, G.context g node_2)
      fused <- vFuseContexts edgs infusable_nodes ctx1 ctx2
      case fused of
        Nothing -> pure dg
        Just (inputs, _, nodeT, outputs) -> do
          nodeT' <- copyConsumed ctx1 ctx2 nodeT
          contractEdge node_2 (inputs, node_1, nodeT', outputs) dg
  | otherwise = pure dg
  where
    -- Labels of the edges between the two nodes.
    edgs = map G.edgeLabel $ edgesBetween dg node_1 node_2
    -- Names carried by consumption edges between the two nodes.
    fusedC = [getName e | e <- edgs, isCons e]
    -- Dependencies on the edges between node_1 and each of its
    -- predecessors other than node_2.
    infusable_nodes =
      [ depsFromEdge e
        | p <- G.pre g node_1,
          p /= node_2,
          e <- edgesBetween dg node_1 p
      ]
    -- If the fused edges involve consumption, make copies of everything
    -- that was not previously consumed.  "Previously consumed" means named
    -- on a consumption edge in the outgoing adjacency of either original
    -- context.
    copyConsumed (_, _, _, deps_1) (_, _, _, deps_2) nodeT
      | null fusedC = pure nodeT
      | otherwise =
          makeCopiesOfFusedExcept
            [getName e | (e, _) <- deps_1 <> deps_2, isCons e]
            nodeT

-- | Try to horizontally fuse @node_1@ and @node_2@.  Fusion requires that
-- both nodes are in the graph, that 'hFusionFeasability' holds, and that
-- 'hFuseContexts' succeeds.  If so, the fused context is passed to
-- 'contractEdge' with @node_2@; otherwise the graph is returned unchanged.
hTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
hTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | bothPresent && hFusionFeasability dg node_1 node_2 =
      maybe (pure dg) (\ctx -> contractEdge node_2 ctx dg)
        =<< hFuseContexts (G.context g node_1) (G.context g node_2)
  | otherwise = pure dg
  where
    bothPresent = G.gelem node_1 g && G.gelem node_2 g

-- | Horizontally fuse the nodes of two contexts with 'hFuseNodeT'.  On
-- success, the fused node is combined with both contexts via
-- 'mergedContext'.
hFuseContexts :: DepContext -> DepContext -> FusionM (Maybe DepContext)
hFuseContexts c1@(_, _, nodeT1, _) c2@(_, _, nodeT2, _) =
  fmap (\nodeT -> mergedContext nodeT c1 c2) <$> hFuseNodeT nodeT1 nodeT2

-- | Vertically fuse the nodes of two contexts; @c1@ is the producer.
--
-- 'vFuseNodeT' receives the following:
--
-- * for the producer: its node, the labels of its first adjacency minus
--   entries for @c2@'s node, and all labels of its last adjacency;
-- * for the consumer: its node and the labels of its last adjacency minus
--   entries for @c1@'s node.
--
-- On success the fused node is combined with both contexts via
-- 'mergedContext'.
vFuseContexts :: [EdgeT] -> [VName] -> DepContext -> DepContext -> FusionM (Maybe DepContext)
vFuseContexts edgs infusable c1@(i1, n1, nodeT1, o1) c2@(_i2, n2, nodeT2, o2) =
  fmap (\nodeT -> mergedContext nodeT c1 c2)
    <$> vFuseNodeT edgs infusable producer consumer
  where
    producer = (nodeT1, [e | (e, other) <- i1, other /= n2], map fst o1)
    consumer = (nodeT2, [e | (e, other) <- o2, other /= n1])

-- | For a 'SoacNode', find the names consumed by the SOAC's lambda
-- (according to alias analysis).  Accumulators are excluded, as are the
-- names in the first argument.  The remaining names are copied inside the
-- lambda via 'makeCopiesInLambda'.  Any other node is returned unchanged.
makeCopiesOfFusedExcept ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  NodeT ->
  m NodeT
makeCopiesOfFusedExcept noCopy (SoacNode ots pats soac aux) = do
  let lam = H.lambda soac
  localScope (scopeOf lam) $ do
    let consumed =
          namesToList . consumedByLambda $ Alias.analyseLambda mempty lam
    -- Accumulator-typed names are never copied.
    non_acc <- filterM (fmap (not . isAcc) . lookupType) consumed
    lam' <- makeCopiesInLambda (non_acc L.\\ noCopy) lam
    pure $ SoacNode ots pats (H.setLambda lam' soac) aux
makeCopiesOfFusedExcept _ nodeT = pure nodeT

-- | Create copies of the given names inside a lambda.  Every occurrence in
-- the original body is renamed to its copy, and the copy statements are
-- added to the body with 'insertStms'.
makeCopiesInLambda ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Lambda SOACS ->
  m (Lambda SOACS)
makeCopiesInLambda toCopy lam = do
  (copyStms, renaming) <- makeCopyStms toCopy
  let renamedBody = substituteNames renaming (lambdaBody lam)
      lam' = lam {lambdaBody = insertStms copyStms renamedBody}
  pure lam'

-- | For each name @v@, create a fresh name with base name
-- @baseString v <> "_copy"@ and a statement binding it to @Copy v@.
-- Returns the statements, in input order, together with the map from
-- original names to their copies.
makeCopyStms ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  m (Stms SOACS, M.Map VName VName)
makeCopyStms vs = do
  fresh <- mapM copyName vs
  let renaming = zip vs fresh
  copies <- forM renaming $ \(old, new) ->
    mkLetNames [new] $ BasicOp $ Copy old
  pure (stmsFromList copies, M.fromList renaming)
  where
    copyName v = newVName $ baseString v <> "_copy"

-- | Whether a SOAC may act as the producer in vertical fusion.  A
-- 'H.Screma' recognised by 'Futhark.isScanomapSOAC' is allowed only while
-- 'fuseScans' is set.  Every other SOAC is always allowed.
okToFuseProducer :: H.SOAC SOACS -> FusionM Bool
okToFuseProducer (H.Screma _ form _)
  | isJust (Futhark.isScanomapSOAC form) = gets fuseScans
okToFuseProducer _ = pure True

-- | Attempt vertical fusion of two nodes.  The first node is the
-- producer, the second is the consumer; 'Nothing' means no fusion took
-- place.
--
-- The leading @['EdgeT']@ argument is not used.  The @['VName']@ argument
-- lists infusible names; the first two cases below require it to be
-- empty.
--
-- * A real producer ('isRealNode') feeding a 'MatchNode' is prepended,
--   together with its edge list, to the match's nested nodes.
-- * A 'TransNode' feeding a 'SoacNode': each SOAC input that reads the
--   transform's output is redirected to the transform's input, with the
--   transform prepended to that input's transforms.
-- * Two 'SoacNode's are combined with 'TF.attemptFusion'.  This requires
--   the producer to pass 'okToFuseProducer' and to have no output
--   transforms.  Success is recorded via 'fusedSomething'.
vFuseNodeT :: [EdgeT] -> [VName] -> (NodeT, [EdgeT], [EdgeT]) -> (NodeT, [EdgeT]) -> FusionM (Maybe NodeT)
vFuseNodeT _ infusible (producer, _, edges1) (MatchNode match_stm nested, _)
  | isRealNode producer,
    null infusible =
      pure . Just . MatchNode match_stm $ (producer, edges1) : nested
vFuseNodeT _ infusible (TransNode trans_out tr trans_in, _, _) (SoacNode ots2 pats2 soac2 aux2, _)
  | null infusible = do
      trans_in_t <- lookupType trans_in
      let retarget inp
            | H.inputArray inp /= trans_out = inp
            | otherwise =
                H.Input (tr H.<| H.inputTransforms inp) trans_in trans_in_t
          soac2' = H.setInputs (map retarget (H.inputs soac2)) soac2
      pure $ Just $ SoacNode ots2 pats2 soac2' aux2
vFuseNodeT
  _
  _
  (SoacNode ots1 pats1 soac1 aux1, i1s, _e1s)
  (SoacNode ots2 pats2 soac2 aux2, _e2s) = do
    let consumer =
          TF.FusedSOAC
            { TF.fsSOAC = soac2,
              TF.fsOutputTransform = ots2,
              TF.fsOutNames = patNames pats2
            }
        -- Infusible dependencies and ordinary dependencies are both kept.
        keepEdge e = case e of
          InfDep {} -> True
          _ -> isDep e
        -- Names on the producer's edges selected by 'keepEdge'; passed to
        -- 'TF.attemptFusion' as the names to preserve.
        preserve = namesFromList [getName e | e <- i1s, keepEdge e]
    ok <- okToFuseProducer soac1
    fused <-
      if not ok || ots1 /= mempty
        then pure Nothing
        else TF.attemptFusion TF.Vertical preserve (patNames pats1) soac1 consumer
    case fused of
      Nothing -> pure Nothing
      Just ker' ->
        fusedSomething $
          SoacNode
            (TF.fsOutputTransform ker')
            (Pat $ zipWith PatElem (TF.fsOutNames ker') (H.typeOf (TF.fsSOAC ker')))
            (TF.fsSOAC ker')
            (aux1 <> aux2)
vFuseNodeT _ _ _ _ = pure Nothing

-- | The result of a lambda's body.
resFromLambda :: Lambda rep -> Result
resFromLambda lam = bodyResult (lambdaBody lam)

-- | Input check used before attempting horizontal fusion.
--
-- Removes the inputs of the first SOAC from those of the second (with
-- list difference, one occurrence per element), passes both sides through
-- 'isNotVarInput', and succeeds when the two resulting lists share no
-- element.
--
-- NOTE(review): assuming 'isNotVarInput' only filters its argument, an
-- element common to both sides must occur in @is2@ more often than in
-- @is1@, so this is nearly always 'True'.  If the intent was to detect the
-- same array read through different transforms (compare by
-- 'H.inputArray'), the check may be weaker than its name suggests --
-- confirm.
hasNoDifferingInputs :: [H.Input] -> [H.Input] -> Bool
hasNoDifferingInputs is1 is2 =
  null (nonVarOf1 `L.intersect` nonVarOf2)
  where
    nonVarOf1 = isNotVarInput is1
    nonVarOf2 = isNotVarInput (is2 L.\\ is1)

-- | Attempt to horizontally fuse two SOAC nodes.
--
-- Both nodes must carry no output transforms, and their inputs must pass
-- 'hasNoDifferingInputs'.  The first SOAC is handed to 'TF.attemptFusion'
-- in horizontal mode together with a kernel built from the second SOAC.
-- All output names of the first SOAC are marked as names to preserve.
-- On success the fused node is passed to 'fusedSomething'.  Otherwise, and
-- for any pair that is not two SOAC nodes, the result is 'Nothing'.
hFuseNodeT :: NodeT -> NodeT -> FusionM (Maybe NodeT)
hFuseNodeT (SoacNode ots1 pats1 soac1 aux1) (SoacNode ots2 pats2 soac2 aux2)
  | ots1 == mempty,
    ots2 == mempty,
    hasNoDifferingInputs (H.inputs soac1) (H.inputs soac2) = do
      let ker =
            TF.FusedSOAC
              { TF.fsSOAC = soac2,
                TF.fsOutputTransform = mempty,
                TF.fsOutNames = patNames pats2
              }
          -- Every output of the first SOAC must survive the fusion.
          preserve = namesFromList $ patNames pats1
      r <- TF.attemptFusion TF.Horizontal preserve (patNames pats1) soac1 ker
      case r of
        Just ker' -> do
          let pats2' =
                zipWith PatElem (TF.fsOutNames ker') (H.typeOf (TF.fsSOAC ker'))
          -- Carry over whatever output transform the fused kernel has, exactly
          -- as vertical fusion does, rather than silently replacing it with
          -- 'mempty'.  When the kernel has no transform, this is identical.
          fusedSomething $
            SoacNode
              (TF.fsOutputTransform ker')
              (Pat pats2')
              (TF.fsSOAC ker')
              (aux1 <> aux2)
        Nothing -> pure Nothing
hFuseNodeT _ _ = pure Nothing

-- | Drop unwanted trailing (map) outputs of a 'H.Screma' node.
--
-- The first @scanResults + redResults@ pattern elements and lambda results
-- (the scan and reduction outputs) are always kept.  Each remaining pattern
-- element is kept, together with its lambda result and return type, only if
-- its name is in @toKeep@.  Any node that is not a Screma 'SoacNode' is
-- returned unchanged.
removeOutputsExcept :: [VName] -> NodeT -> NodeT
removeOutputsExcept toKeep s = case s of
  SoacNode ots (Pat pats1) soac@(H.Screma _ (ScremaForm scans_1 red_1 lam_1) _) aux1 ->
    SoacNode ots (Pat $ pats_unchanged <> pats_new) (H.setLambda lam_new soac) aux1
    where
      -- Number of leading outputs belonging to scans and reductions.
      num_nonmap = Futhark.scanResults scans_1 + Futhark.redResults red_1
      -- Build the set once, so each output is checked in logarithmic time
      -- instead of by a linear scan of the list.
      keep = namesFromList toKeep

      (pats_unchanged, pats_toChange) = splitAt num_nonmap pats1
      (res_unchanged, res_toChange) =
        splitAt num_nonmap $ zip (resFromLambda lam_1) (lambdaReturnType lam_1)

      (pats_new, other) =
        unzip . filter ((`nameIn` keep) . patElemName . fst) $
          zip pats_toChange res_toChange
      (results, types) = unzip (res_unchanged ++ other)
      lam_new =
        lam_1
          { lambdaReturnType = types,
            lambdaBody = (lambdaBody lam_1) {bodyResult = results}
          }
  node -> node

-- | Name carried by one adjacency entry @(label, from)@ of node @to@.
-- The labelled edge @from -> to@ is rebuilt and handed to 'depsFromEdge'.
vNameFromAdj :: G.Node -> (EdgeT, G.Node) -> VName
vNameFromAdj to (lbl, from) = depsFromEdge (from, to, lbl)

-- | Restrict a context's node to the outputs named by its incoming
-- adjacency list (see 'removeOutputsExcept').  The node id and both
-- adjacency lists are left untouched.
removeUnusedOutputsFromContext :: DepContext -> FusionM DepContext
removeUnusedOutputsFromContext (incoming, n1, nodeT, outgoing) =
  let used = map (vNameFromAdj n1) incoming
   in pure (incoming, n1, removeOutputsExcept used nodeT, outgoing)

-- | Prune unused SOAC outputs throughout the graph by applying
-- 'removeUnusedOutputsFromContext' to every context.
removeUnusedOutputs :: DepGraphAug FusionM
removeUnusedOutputs dg = mapAcross removeUnusedOutputsFromContext dg

-- | Try to vertically fuse a node with each of its graph predecessors that
-- is connected to it by a dependency edge ('isDep').  If the node is no
-- longer in the graph (it may have been fused away since the node list was
-- computed), the graph is returned unchanged.
tryFuseNodeInGraph :: DepNode -> DepGraphAug FusionM
tryFuseNodeInGraph node_to_fuse dg@DepGraph {dgGraph = g}
  | not (G.gelem node_to_fuse_id g) = pure dg
  | otherwise = do
      -- Query the predecessors only once the node is known to exist.
      -- 'G.lpre' errors on a missing node, and under the module's @Strict@
      -- pragma a @where@-bound list would be forced before the check above.
      let fuses_with = [n | (n, e) <- G.lpre g node_to_fuse_id, isDep e]
      applyAugs (map (vTryFuseNodesInGraph node_to_fuse_id) fuses_with) dg
  where
    node_to_fuse_id = nodeFromLNode node_to_fuse

-- | Attempt vertical fusion from every node except plain statement and
-- result nodes.  Nodes are visited in the reverse of 'G.labNodes' order.
doVerticalFusion :: DepGraphAug FusionM
doVerticalFusion dg = applyAugs (map tryFuseNodeInGraph candidates) dg
  where
    candidates =
      reverse [lnode | lnode@(_, nt) <- G.labNodes (dgGraph dg), relevant nt]
    relevant StmNode {} = False
    relevant ResNode {} = False
    relevant _ = True

-- | For each pair of SOAC nodes that share an input, attempt to fuse
-- them horizontally.
doHorizontalFusion :: DepGraphAug FusionM
doHorizontalFusion dg = applyAugs pairs dg
  where
    soacs = [(n, soac) | (n, SoacNode _ _ soac _) <- G.labNodes (dgGraph dg)]

    -- One augmentation per ordered pair x < y of SOAC nodes that read at
    -- least one common array.
    pairs :: [DepGraphAug FusionM]
    pairs =
      [ fuseIfStillPresent x y
        | (x, soac_x) <- soacs,
          (y, soac_y) <- soacs,
          x < y,
          sharesInput soac_x soac_y
      ]

    sharesInput soac_x soac_y =
      let arrs_x = map H.inputArray (H.inputs soac_x)
       in any ((`elem` arrs_x) . H.inputArray) (H.inputs soac_y)

    -- Nodes might have been fused away by the time this pair is tried.
    fuseIfStillPresent x y dg'
      | G.gelem x g' && G.gelem y g' = hTryFuseNodesInGraph x y dg'
      | otherwise = pure dg'
      where
        g' = dgGraph dg'

-- | Fuse inside the nested bodies and lambdas of every node by applying
-- 'runInnerFusionOnContext' to each context.
doInnerFusion :: DepGraphAug FusionM
doInnerFusion dg = mapAcross runInnerFusionOnContext dg

-- | Fixed-point iteration.  The augmentation is applied repeatedly until one
-- run leaves 'fusionCount' unchanged.  The graph produced by that final run
-- is returned.
keepTrying :: DepGraphAug FusionM -> DepGraphAug FusionM
keepTrying f = go
  where
    go g = do
      before <- gets fusionCount
      g' <- f g
      after <- gets fusionCount
      if before == after then pure g' else go g'

-- | The whole fusion pipeline for one graph.  Vertical, horizontal and inner
-- fusion are repeated until no further fusion happens ('keepTrying').
-- Unused SOAC outputs are then pruned.
doAllFusion :: DepGraphAug FusionM
doAllFusion = applyAugs [fuseToFixpoint, removeUnusedOutputs]
  where
    fuseToFixpoint =
      keepTrying $
        applyAugs [doVerticalFusion, doHorizontalFusion, doInnerFusion]

-- | Run fusion on the code nested inside a single graph node.
--
-- * 'DoNode' and 'MatchNode':
--   the delayed nodes stored with the statement are finalised with
--   'finalizeNode', and their statements are prepended to every nested
--   body.  Each body is then fused as a whole, and the rebuilt node carries
--   an empty delayed list.  Match case bodies are also renamed after fusion,
--   while the default body is not.  Presumably this keeps statements copied
--   into several branches from binding the same names.
-- * 'Futhark.VJP', 'Futhark.JVP' and 'WithAcc' statements, and all SOAC
--   nodes: the lambda is fused via 'doFusionLambda'.  Scan fusion is
--   switched off ('dontFuseScans') for 'H.Stream' lambdas.  It is switched
--   on ('doFuseScans') in every other case handled here.
-- * Every other node is returned unchanged.
--
-- Only the node label changes; the node id and edges are kept.
runInnerFusionOnContext :: DepContext -> FusionM DepContext
runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) = case nodeT of
  DoNode (Let pat aux (DoLoop params form body)) to_fuse ->
    doFuseScans . localScope (scopeOfFParams (map fst params) <> scopeOf form) $ do
      body' <- doFusionWithDelayed body to_fuse
      withLabel $ DoNode (Let pat aux (DoLoop params form body')) []
  MatchNode (Let pat aux (Match cond cases defbody dec)) to_fuse -> doFuseScans $ do
    cases' <- mapM (traverse $ renameBody <=< (`doFusionWithDelayed` to_fuse)) cases
    defbody' <- doFusionWithDelayed defbody to_fuse
    withLabel $ MatchNode (Let pat aux (Match cond cases' defbody' dec)) []
  StmNode (Let pat aux (Op (Futhark.VJP lam args vec))) ->
    fuseStmLambda lam $ \lam' -> Let pat aux (Op (Futhark.VJP lam' args vec))
  StmNode (Let pat aux (Op (Futhark.JVP lam args vec))) ->
    fuseStmLambda lam $ \lam' -> Let pat aux (Op (Futhark.JVP lam' args vec))
  StmNode (Let pat aux (WithAcc inputs lam)) ->
    fuseStmLambda lam $ \lam' -> Let pat aux (WithAcc inputs lam')
  SoacNode ots pat soac aux -> do
    let lam = H.lambda soac
        -- Scan fusion is disabled inside streams only.
        scanPolicy = case soac of
          H.Stream {} -> dontFuseScans
          _ -> doFuseScans
    lam' <- localScope (scopeOf lam) $ scanPolicy $ doFusionLambda lam
    withLabel $ SoacNode ots pat (H.setLambda lam' soac) aux
  _ -> pure c
  where
    -- Put a new node label back into the otherwise unchanged context.
    withLabel nodeT' = pure (incoming, node, nodeT', outgoing)

    -- Fuse inside a statement's lambda, then rebuild the statement with it.
    fuseStmLambda lam mkStm = doFuseScans $ do
      lam' <- doFusionLambda lam
      withLabel $ StmNode $ mkStm lam'

    -- Finalise the delayed nodes, prepend their statements to the body, and
    -- fuse everything as a single graph.  The body's result is kept as is.
    doFusionWithDelayed :: Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
    doFusionWithDelayed (Body () stms res) extraNodes =
      localScope (scopeOf stms) $ do
        delayed_stms <- mconcat <$> mapM (finalizeNode . fst) extraNodes
        Body () <$> fuseGraph (mkBody (delayed_stms <> stms) res) <*> pure res

    -- Simplify first, to clean up previous instances of fusion, then fuse the
    -- lambda body.  Simplify once more only if that inner fusion changed
    -- 'fusionCount'.
    doFusionLambda :: Lambda SOACS -> FusionM (Lambda SOACS)
    doFusionLambda lam = do
      lam' <- simplifyLambda lam
      before <- gets fusionCount
      stms' <- localScope (scopeOf lam') $ fuseGraph (lambdaBody lam')
      after <- gets fusionCount
      let fused = lam' {lambdaBody = (lambdaBody lam') {bodyStms = stms'}}
      if before /= after then simplifyLambda fused else pure fused

-- | Main fusion function.  It builds the dependency graph of a body, runs
-- the full fusion pipeline ('doAllFusion') on it, and linearises the result
-- back into statements.  The body's statements are in scope throughout.
fuseGraph :: Body SOACS -> FusionM (Stms SOACS)
fuseGraph body =
  localScope (scopeOf (bodyStms body)) $
    linearizeGraph =<< doAllFusion =<< mkDepGraph body

-- | Fuse the top-level constant statements.  The given names become the
-- result of the body that is fused.  In this module they are the names
-- free in the program's functions, presumably so that fusion keeps their
-- definitions available.
fuseConsts :: [VName] -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts outputs stms =
  runFusionM (scopeOf stms) freshFusionEnv $
    fuseGraph $ mkBody stms (varsRes outputs)

-- | Fuse the body of one function.  Both the function's own scope and the
-- top-level constants are in scope.  Only the body's statements are
-- replaced; its result is left as it was.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  let body = funDefBody fun
      scope = scopeOf fun <> scopeOf consts
  stms' <- runFusionM scope freshFusionEnv (fuseGraph body)
  pure fun {funDefBody = body {bodyStms = stms'}}

-- | The pass definition.
{-# NOINLINE fuseSOACs #-}
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = fuseProg
    }
  where
    -- Constants referenced by any function are the results of the fused
    -- constant body.  Each function is then fused on its own.
    fuseProg prog =
      intraproceduralTransformationWithConsts
        (fuseConsts (namesToList (freeIn (progFuns prog))))
        fuseFun
        prog