{-# LANGUAGE Strict #-}

-- | Perform horizontal and vertical fusion of SOACs.  See the paper
-- /A T2 Graph-Reduction Approach To Fusion/ for the basic idea (some
-- extensions discussed in /Design and GPGPU Performance of Futhark’s
-- Redomap Construct/).
module Futhark.Optimise.Fusion (fuseSOACs) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Optimise.Fusion.TryFusion qualified as TF
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- | State threaded through every 'FusionM' computation.
data FusionEnv = FusionEnv
  { -- | Source of fresh names; read and written by the
    -- 'MonadFreshNames' instance for 'FusionM'.
    vNameSource :: VNameSource,
    -- | Number of fusions performed so far; incremented by
    -- 'fusedSomething'.
    fusionCount :: Int,
    -- | Flag overridden by 'doFuseScans' / 'dontFuseScans' for the
    -- duration of a sub-computation; presumably consulted elsewhere to
    -- decide whether scans may be fused.
    fuseScans :: Bool
  }

-- | Starting state for a fusion run: an empty name source, a fusion
-- counter of zero, and 'fuseScans' switched on.
freshFusionEnv :: FusionEnv
freshFusionEnv =
  FusionEnv
    { fuseScans = True,
      fusionCount = 0,
      vNameSource = blankNameSource
    }

-- | The fusion monad: a read-only 'Scope' of the program being fused,
-- layered over mutable 'FusionEnv' state.  Every instance in the
-- deriving clause is inherited from the underlying transformer stack.
newtype FusionM a = FusionM (ReaderT (Scope SOACS) (State FusionEnv) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState FusionEnv,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Fresh names are read from, and written back to, the 'vNameSource'
-- field of the 'FusionEnv' state.
instance MonadFreshNames FusionM where
  getNameSource = gets vNameSource
  putNameSource src = modify $ \env -> env {vNameSource = src}

-- | Run a 'FusionM' computation under the given scope, starting from
-- the given environment.
--
-- The 'vNameSource' inside the supplied environment is ignored.  The
-- computation instead starts from the ambient monad's name source, and
-- the name source it finishes with is handed back to that monad.  The
-- rest of the final 'FusionEnv' is discarded.
runFusionM :: (MonadFreshNames m) => Scope SOACS -> FusionEnv -> FusionM a -> m a
runFusionM scope fenv (FusionM action) =
  modifyNameSource $ \src ->
    let (result, fenv') = runState (runReaderT action scope) (fenv {vNameSource = src})
     in (result, vNameSource fenv')

-- | Run a computation with 'fuseScans' set to 'True', restoring the
-- previous setting afterwards.
doFuseScans :: FusionM a -> FusionM a
doFuseScans = withFuseScansSetTo True

-- | Run a computation with 'fuseScans' set to 'False', restoring the
-- previous setting afterwards.
dontFuseScans :: FusionM a -> FusionM a
dontFuseScans = withFuseScansSetTo False

-- | Temporarily override 'fuseScans' for the duration of a computation.
-- The saved value is written back once the computation returns.  The
-- 'FusionM' stack (a ReaderT over State) has no error-handling layer, so
-- there is no early-exit path that could skip the restore.
withFuseScansSetTo :: Bool -> FusionM a -> FusionM a
withFuseScansSetTo flag m = do
  saved <- gets fuseScans
  modify $ \s -> s {fuseScans = flag}
  r <- m
  modify $ \s -> s {fuseScans = saved}
  pure r

-- | 'True' when, according to 'reachable', neither node reaches the
-- other: @a@ does not reach @b@, and @b@ does not reach @a@.
unreachableEitherDir :: DepGraph -> G.Node -> G.Node -> Bool
unreachableEitherDir g a b =
  not (reachable g a b) && not (reachable g b a)

-- | Keep only those inputs for which 'H.isVarInput' returns 'Nothing'.
isNotVarInput :: [H.Input] -> [H.Input]
isNotVarInput inputs = [inp | inp <- inputs, isNothing (H.isVarInput inp)]

-- | Turn a single dependency-graph node back into statements.
--
-- * 'StmNode' yields its statement unchanged.
-- * 'SoacNode' binds the SOAC (under the node's 'StmAux') to fresh
--   names.  It then binds each original pattern name to the matching
--   fresh name with the node's pending array transforms applied.
-- * 'TransNode' binds @output@ to the expression produced by
--   'H.transformToExp' for transform @tr@ on @ia@, certified by the
--   certificates returned alongside it.
-- * 'DoNode' and 'MatchNode' emit the statements of their nested nodes,
--   followed by the node's own statement.
-- * 'ResNode' and 'FreeNode' produce no statements.
finalizeNode :: (HasScope SOACS m, MonadFreshNames m) => NodeT -> m (Stms SOACS)
finalizeNode nt = case nt of
  StmNode stm -> pure $ oneStm stm
  SoacNode ots outputs soac aux -> runBuilder_ $ do
    untransformed_outputs <- mapM newName $ patNames outputs
    auxing aux $ letBindNames untransformed_outputs . Op =<< H.toSOAC soac
    forM_ (zip (patNames outputs) untransformed_outputs) $ \(output, v) ->
      letBindNames [output] . BasicOp . SubExp . Var =<< H.applyTransforms ots v
  ResNode _ -> pure mempty
  TransNode output tr ia -> do
    (cs, e) <- H.transformToExp tr ia
    runBuilder_ $ certifying cs $ letBindNames [output] e
  FreeNode _ -> pure mempty
  DoNode stm lst -> finalizeNested stm lst
  MatchNode stm lst -> finalizeNested stm lst
  where
    -- Shared by 'DoNode' and 'MatchNode': the nested nodes' statements
    -- first, then the enclosing statement itself.
    finalizeNested stm lst = do
      nested <- mapM (finalizeNode . fst) lst
      pure $ mconcat nested <> oneStm stm

-- | Flatten a whole dependency graph into one statement sequence.  Each
-- node label is finalised in the reverse of 'Q.topsort'' order, and the
-- resulting statement groups are concatenated.
linearizeGraph :: (HasScope SOACS m, MonadFreshNames m) => DepGraph -> m (Stms SOACS)
linearizeGraph dg =
  mconcat <$> traverse finalizeNode (reverse (Q.topsort' (dgGraph dg)))

-- | Record that a fusion took place by incrementing 'fusionCount', then
-- return the fused node wrapped in 'Just'.
fusedSomething :: NodeT -> FusionM (Maybe NodeT)
fusedSomething x = Just x <$ modify bump
  where
    bump env = env {fusionCount = fusionCount env + 1}

-- | Vertical fusion of @n1@ with @n2@ is allowed when both of the
-- following hold:
--
-- 1. No edge returned by 'edgesBetween' for the pair satisfies 'isInf'.
--    Judging by the name, 'isInf' marks edges as infusible.
-- 2. @n2@ does not 'reachable'-reach any predecessor of @n1@ other than
--    @n2@ itself.  Presumably this guarantees that contracting the pair
--    cannot introduce a cycle.
vFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
vFusionFeasability dg@DepGraph {dgGraph = g} n1 n2 =
  all (not . isInf) (edgesBetween dg n1 n2)
    && all (not . reachable dg n2) (filter (/= n2) (G.pre g n1))

-- | Horizontal fusion of two nodes is allowed exactly when neither node
-- can reach the other ('unreachableEitherDir').
hFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
hFusionFeasability = unreachableEitherDir

-- | Attempt vertical fusion of @node_1@ with @node_2@.
--
-- The graph is returned unchanged unless both nodes are present and
-- 'vFusionFeasability' holds.  Otherwise the two contexts are passed to
-- 'vFuseContexts', together with:
--
-- * the labels of the edges between the two nodes, and
-- * the 'depsFromEdge' names of the edges between @node_1@ and each of
--   its other predecessors.
--
-- If fusion succeeds and some edge between the nodes is a consumption
-- ('isCons'), 'makeCopiesOfFusedExcept' is applied to the fused label.
-- It is exempted from copying names already named by consuming edges in
-- the fourth (outgoing) context component of either node.  Finally
-- 'contractEdge' is called with @node_2@ and a context for @node_1@ that
-- carries the fused label.
--
-- The edge-derived values are computed only inside the feasible branch.
-- Under @{-# LANGUAGE Strict #-}@, @where@-bindings are strict and would
-- be forced before the guards run.  That includes @G.pre g node_1@, so
-- the 'G.gelem' check could not protect them.
vTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
-- find the neighbors -> verify that fusion causes no cycles -> fuse
vTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | vFusionFeasability dg node_1 node_2 = do
      let (ctx1, ctx2) = (G.context g node_1, G.context g node_2)
          edgs = map G.edgeLabel $ edgesBetween dg node_1 node_2
          fusedC = map getName $ filter isCons edgs
          infusable_nodes =
            map depsFromEdge $
              concatMap (edgesBetween dg node_1) $
                filter (/= node_2) (G.pre g node_1)
      fres <- vFuseContexts edgs infusable_nodes ctx1 ctx2
      case fres of
        Just (inputs, _, nodeT, outputs) -> do
          nodeT' <-
            if null fusedC
              then pure nodeT
              else do
                let (_, _, _, deps_1) = ctx1
                    (_, _, _, deps_2) = ctx2
                    -- make copies of everything that was not
                    -- previously consumed
                    old_cons = map (getName . fst) $ filter (isCons . fst) (deps_1 <> deps_2)
                makeCopiesOfFusedExcept old_cons nodeT
          contractEdge node_2 (inputs, node_1, nodeT', outputs) dg
        Nothing -> pure dg
  | otherwise = pure dg

-- | Attempt horizontal fusion of @node_1@ with @node_2@.
--
-- The graph is returned unchanged in three cases: either node is
-- missing, 'hFusionFeasability' rejects the pair, or 'hFuseContexts'
-- yields 'Nothing'.  Otherwise 'contractEdge' is called with @node_2@
-- and the merged context.
hTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
hTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | G.gelem node_1 g && G.gelem node_2 g && hFusionFeasability dg node_1 node_2 =
      maybe (pure dg) (\ctx -> contractEdge node_2 ctx dg)
        =<< hFuseContexts (G.context g node_1) (G.context g node_2)
  | otherwise = pure dg

-- | Try to fuse the labels of two contexts horizontally with
-- 'hFuseNodeT'.  On success, the two contexts are combined around the
-- fused label with 'mergedContext'; otherwise the result is 'Nothing'.
hFuseContexts :: DepContext -> DepContext -> FusionM (Maybe DepContext)
hFuseContexts c1@(_, _, nodeT1, _) c2@(_, _, nodeT2, _) = do
  fused <- hFuseNodeT nodeT1 nodeT2
  pure $ case fused of
    Nothing -> Nothing
    Just nodeT -> Just $ mergedContext nodeT c1 c2

-- | Try to fuse the labels of two contexts vertically with 'vFuseNodeT'.
--
-- 'vFuseNodeT' receives:
--
-- * the given edge labels and infusible names;
-- * the first node, with its incoming edge labels (dropping entries for
--   the second node) and all of its outgoing edge labels;
-- * the second node, with its outgoing edge labels (dropping entries for
--   the first node).
--
-- On success, the contexts are combined around the fused label with
-- 'mergedContext'.
vFuseContexts :: [EdgeT] -> [VName] -> DepContext -> DepContext -> FusionM (Maybe DepContext)
vFuseContexts edgs infusable c1@(i1, n1, nodeT1, o1) c2@(_, n2, nodeT2, o2) = do
  fused <-
    vFuseNodeT
      edgs
      infusable
      (nodeT1, labelsAvoiding n2 i1, map fst o1)
      (nodeT2, labelsAvoiding n1 o2)
  pure $ case fused of
    Nothing -> Nothing
    Just nodeT -> Just $ mergedContext nodeT c1 c2
  where
    -- Edge labels of an adjacency list, skipping entries attached to @n@.
    labelsAvoiding n adj = [lbl | (lbl, other) <- adj, other /= n]

-- | For a 'SoacNode', hand its lambda to 'makeCopiesInLambda' together
-- with a list of names.  That list is built as follows:
--
-- 1. Take the names the lambda consumes, per alias analysis
--    ('Alias.analyseLambda' and 'consumedByLambda').
-- 2. Drop names whose type is an accumulator ('isAcc').
-- 3. Drop names listed in @noCopy@.
--
-- Presumably 'makeCopiesInLambda' inserts copies of those names; its
-- body is not shown here.  Types are looked up in a scope extended with
-- the lambda's own bindings.  All other node kinds are returned
-- unchanged.
makeCopiesOfFusedExcept ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  NodeT ->
  m NodeT
makeCopiesOfFusedExcept noCopy (SoacNode ots pats soac aux) =
  localScope (scopeOf lam) $ do
    let consumed = namesToList . consumedByLambda $ Alias.analyseLambda mempty lam
    toCopy <- filterM (fmap (not . isAcc) . lookupType) consumed
    lam' <- makeCopiesInLambda (toCopy L.\\ noCopy) lam
    pure $ SoacNode ots pats (H.setLambda lam' soac) aux
  where
    lam = H.lambda soac
makeCopiesOfFusedExcept _ nodeT = pure nodeT

-- | Replace every use of the given names in the lambda body with a fresh
-- copy, and add the statements that create those copies to the body
-- using 'insertStms'.  The substitution is applied before the copy
-- statements are inserted, so the copies themselves still read the
-- original names.
makeCopiesInLambda ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Lambda SOACS ->
  m (Lambda SOACS)
makeCopiesInLambda toCopy lam = do
  (copy_stms, renaming) <- makeCopyStms toCopy
  let renamed_body = substituteNames renaming $ lambdaBody lam
  pure lam {lambdaBody = insertStms copy_stms renamed_body}

-- | For each name @v@, create a fresh name based on @v@ with the suffix
-- @_copy@.  Bind it with a statement @replicate [] v@, i.e. a copy of
-- @v@.  Return the copy statements and a map from every original name
-- to its copy.
makeCopyStms ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  m (Stms SOACS, M.Map VName VName)
makeCopyStms vs = do
  -- All fresh names are generated first, then all copy statements.
  pairs <- traverse withCopyName vs
  copies <- traverse copyStm pairs
  pure (stmsFromList copies, M.fromList pairs)
  where
    withCopyName v = do
      v' <- newVName (baseString v <> "_copy")
      pure (v, v')
    copyStm (v, v') =
      mkLetNames [v'] . BasicOp . Replicate mempty $ Var v

-- | Whether a SOAC may act as the producer in vertical fusion.  A
-- 'H.Screma' recognised by 'Futhark.isScanomapSOAC' is allowed only
-- when the 'fuseScans' flag of the fusion state is set.  Every other
-- SOAC is always allowed.
okToFuseProducer :: H.SOAC SOACS -> FusionM Bool
okToFuseProducer soac =
  case soac of
    H.Screma _ form _
      | isJust (Futhark.isScanomapSOAC form) -> gets fuseScans
    _ -> pure True

-- | Attempt vertical fusion of two nodes.
--
-- The first node is the producer, given together with its incoming and
-- outgoing edges.  The second node is the consumer, given with its
-- edges.  The leading @[EdgeT]@ argument is not used by any case.
-- Returns 'Nothing' when fusion is not possible.
--
-- Cases, tried in order:
--
-- * Consumer is a 'MatchNode', producer is a real node (per
--   'isRealNode'), and the infusible list is empty: the producer and its
--   outgoing edges are prepended to the match node's list of nodes.
--
-- * Producer is a 'TransNode', consumer is a 'SoacNode', and the
--   infusible list is empty: every consumer input that reads the
--   transform's output is redirected to the transform's input array,
--   with the transform pushed onto that input's transforms.
--
-- * Both are 'SoacNode's: fusion is delegated to 'TF.attemptFusion'.
--   It is only tried when 'okToFuseProducer' allows it and the producer
--   has no output transforms.
vFuseNodeT ::
  [EdgeT] ->
  [VName] ->
  (NodeT, [EdgeT], [EdgeT]) ->
  (NodeT, [EdgeT]) ->
  FusionM (Maybe NodeT)
vFuseNodeT _ infusible (s1, _, e1s) (MatchNode stm2 dfused, _)
  | isRealNode s1 && null infusible =
      pure . Just . MatchNode stm2 $ (s1, e1s) : dfused
vFuseNodeT _ infusible (TransNode stm1_out tr stm1_in, _, _) (SoacNode ots2 pats2 soac2 aux2, _)
  | null infusible = do
      stm1_in_t <- lookupType stm1_in
      let redirect inp
            | H.inputArray inp /= stm1_out = inp
            | otherwise = H.Input (tr H.<| H.inputTransforms inp) stm1_in stm1_in_t
          soac2' = H.setInputs (map redirect (H.inputs soac2)) soac2
      pure $ Just $ SoacNode ots2 pats2 soac2' aux2
vFuseNodeT _ _ (SoacNode ots1 pats1 soac1 aux1, i1s, _) (SoacNode ots2 pats2 soac2 aux2, _) = do
  ok <- okToFuseProducer soac1
  r <-
    if ok && ots1 == mempty
      then TF.attemptFusion TF.Vertical preserve (patNames pats1) soac1 ker
      else pure Nothing
  case r of
    Nothing -> pure Nothing
    Just ker' ->
      fusedSomething $
        SoacNode
          (TF.fsOutputTransform ker')
          (Pat $ zipWith PatElem (TF.fsOutNames ker') (H.typeOf $ TF.fsSOAC ker'))
          (TF.fsSOAC ker')
          (aux1 <> aux2)
  where
    -- The consumer, packaged as the fusion kernel that the producer is
    -- fused into.
    ker =
      TF.FusedSOAC
        { TF.fsSOAC = soac2,
          TF.fsOutputTransform = ots2,
          TF.fsOutNames = patNames pats2
        }
    -- Names carried by the producer's incoming 'InfDep' edges and
    -- dependency edges; handed to 'TF.attemptFusion' as names to
    -- preserve.
    preserve = namesFromList [getName e | e <- i1s, keep e]
    keep InfDep {} = True
    keep e = isDep e
vFuseNodeT _ _ _ _ = pure Nothing

-- | The result list of a lambda's body.
resFromLambda :: Lambda rep -> Result
resFromLambda lam = bodyResult (lambdaBody lam)

-- | Precondition for horizontal fusion of two SOACs with input lists
-- @is1@ and @is2@.
--
-- It applies 'isNotVarInput' to @is1@, and also to what is left of @is2@
-- after removing the inputs of @is1@ with 'L.\\'.  It returns 'True'
-- when those two lists have no element in common.
--
-- NOTE(review): both '(L.\\)' and 'L.intersect' use 'H.Input''s 'Eq'
-- instance.  Assuming 'isNotVarInput' only filters its argument, an
-- input can be common to both lists only if it occurs more often in
-- @is2@ than in @is1@.  Confirm that this is the intended test.
hasNoDifferingInputs :: [H.Input] -> [H.Input] -> Bool
hasNoDifferingInputs is1 is2 =
  null (L.intersect lhs rhs)
  where
    lhs = isNotVarInput is1
    rhs = isNotVarInput (is2 L.\\ is1)

-- | Attempt horizontal fusion of two 'SoacNode's using 'TF.attemptFusion'.
--
-- Fusion is only tried when neither node has output transforms and
-- 'hasNoDifferingInputs' holds for their inputs.  All output names of
-- the first node are passed as names to preserve.  The fused node has
-- no output transforms and combines both nodes' auxiliary information.
-- Any other pair of nodes yields 'Nothing'.
hFuseNodeT :: NodeT -> NodeT -> FusionM (Maybe NodeT)
hFuseNodeT (SoacNode ots1 pats1 soac1 aux1) (SoacNode ots2 pats2 soac2 aux2)
  | all (== mempty) [ots1, ots2],
    hasNoDifferingInputs (H.inputs soac1) (H.inputs soac2) = do
      let outs1 = patNames pats1
          ker =
            TF.FusedSOAC
              { TF.fsSOAC = soac2,
                TF.fsOutputTransform = mempty,
                TF.fsOutNames = patNames pats2
              }
      r <- TF.attemptFusion TF.Horizontal (namesFromList outs1) outs1 soac1 ker
      case r of
        Nothing -> pure Nothing
        Just ker' -> do
          let fused = TF.fsSOAC ker'
              pat = Pat $ zipWith PatElem (TF.fsOutNames ker') (H.typeOf fused)
          fusedSomething $ SoacNode mempty pat fused (aux1 <> aux2)
hFuseNodeT _ _ = pure Nothing

-- | For a 'SoacNode' holding a 'H.Screma', drop the outputs whose names
-- are not in the given list.
--
-- The first @scanResults + redResults@ pattern elements and lambda
-- results are always kept.  Among the remaining outputs, only those
-- whose pattern name is in the list survive.  The lambda's results and
-- return types are pruned in step with the pattern.  Any other node is
-- returned unchanged.
removeOutputsExcept :: [VName] -> NodeT -> NodeT
removeOutputsExcept toKeep node =
  case node of
    SoacNode ots (Pat pes) soac@(H.Screma _ (ScremaForm scans reds lam) _) aux ->
      let num_fixed = Futhark.scanResults scans + Futhark.redResults reds
          lam_res = zip (resFromLambda lam) (lambdaReturnType lam)
          (fixed_pes, rest_pes) = splitAt num_fixed pes
          (fixed_res, rest_res) = splitAt num_fixed lam_res
          (kept_pes, kept_res) =
            unzip
              [ (pe, r)
                | (pe, r) <- zip rest_pes rest_res,
                  patElemName pe `elem` toKeep
              ]
          (results, types) = unzip $ fixed_res ++ kept_res
          lam' =
            lam
              { lambdaReturnType = types,
                lambdaBody = (lambdaBody lam) {bodyResult = results}
              }
       in SoacNode ots (Pat $ fixed_pes <> kept_pes) (H.setLambda lam' soac) aux
    _ -> node

-- | The name carried by an adjacency entry of node @dst@.  The entry is
-- treated as an edge running from the adjacent node to @dst@, as for an
-- entry of an fgl context's incoming adjacency list.
vNameFromAdj :: G.Node -> (EdgeT, G.Node) -> VName
vNameFromAdj dst (lbl, src) = depsFromEdge (src, dst, lbl)

-- | Prune the outputs of a context's node with 'removeOutputsExcept'.
-- Only the names carried by the node's incoming edges are kept.  The
-- adjacency lists are returned unchanged.
removeUnusedOutputsFromContext :: DepContext -> FusionM DepContext
removeUnusedOutputsFromContext (incoming, n1, nodeT, outgoing) =
  let used = [vNameFromAdj n1 adj | adj <- incoming]
      pruned = removeOutputsExcept used nodeT
   in pure (incoming, n1, pruned, outgoing)

-- | Apply 'removeUnusedOutputsFromContext' to every context of the graph.
removeUnusedOutputs :: DepGraphAug FusionM
removeUnusedOutputs dg = mapAcross removeUnusedOutputsFromContext dg

-- | Try vertical fusion (via 'vTryFuseNodesInGraph') of the given node
-- with each of its 'G.lpre' predecessors connected by an 'isDep' edge.
-- If the node is no longer in the graph, the graph is returned
-- unchanged.
tryFuseNodeInGraph :: DepNode -> DepGraphAug FusionM
tryFuseNodeInGraph node_to_fuse dg@DepGraph {dgGraph = g}
  -- Node might have been fused away since.
  | not (G.gelem node_to_fuse_id g) = pure dg
  | otherwise = do
      -- The predecessors are computed only after the membership check.
      -- Under this module's Strict pragma, a where-binding would be
      -- forced before the guard, and 'G.lpre' fails on a missing node.
      let fuses_with = map fst $ filter (isDep . snd) $ G.lpre g node_to_fuse_id
      applyAugs (map (vTryFuseNodesInGraph node_to_fuse_id) fuses_with) dg
  where
    node_to_fuse_id = nodeFromLNode node_to_fuse

-- | Run 'tryFuseNodeInGraph' for every node of the graph except
-- 'StmNode's and 'ResNode's.  Nodes are visited in the reverse of
-- 'G.labNodes' order.
doVerticalFusion :: DepGraphAug FusionM
doVerticalFusion dg =
  applyAugs (map tryFuseNodeInGraph candidates) dg
  where
    candidates =
      reverse [n | n@(_, nodeT) <- G.labNodes (dgGraph dg), fusible nodeT]
    fusible StmNode {} = False
    fusible ResNode {} = False
    fusible _ = True

-- | For each pair of SOAC nodes that share an input array, attempt to
-- fuse them horizontally with 'hTryFuseNodesInGraph'.  Pairs are taken
-- with the smaller node id first.  A pair is skipped if either node is
-- no longer in the graph when its turn comes.
doHorizontalFusion :: DepGraphAug FusionM
doHorizontalFusion dg = applyAugs pairs dg
  where
    soacs = [(n, soac) | (n, SoacNode _ _ soac _) <- G.labNodes $ dgGraph dg]
    pairs :: [DepGraphAug FusionM]
    pairs =
      [ fuseIfPresent x y
        | (x, soac_x) <- soacs,
          (y, soac_y) <- soacs,
          x < y,
          -- Must share an input.
          let arrs_x = map H.inputArray (H.inputs soac_x),
          any ((`elem` arrs_x) . H.inputArray) (H.inputs soac_y)
      ]
    fuseIfPresent x y dg'
      -- Nodes might have been fused away by now.
      | G.gelem x g' && G.gelem y g' = hTryFuseNodesInGraph x y dg'
      | otherwise = pure dg'
      where
        g' = dgGraph dg'

-- | Apply 'runInnerFusionOnContext' to every context of the graph.
doInnerFusion :: DepGraphAug FusionM
doInnerFusion dg = mapAcross runInnerFusionOnContext dg

-- | Fixed-point iteration.  Apply the augmentation repeatedly until one
-- application leaves 'fusionCount' unchanged, then return the graph
-- produced by that last application.
keepTrying :: DepGraphAug FusionM -> DepGraphAug FusionM
keepTrying f = go
  where
    go g = do
      before <- gets fusionCount
      g' <- f g
      after <- gets fusionCount
      if before == after then pure g' else go g'

-- | The full fusion pipeline.  Vertical, horizontal and inner fusion are
-- repeated (via 'keepTrying') until a round leaves 'fusionCount'
-- unchanged.  After that, unused outputs are removed.
doAllFusion :: DepGraphAug FusionM
doAllFusion = applyAugs [keepTrying fuseRound, removeUnusedOutputs]
  where
    fuseRound = applyAugs [doVerticalFusion, doHorizontalFusion, doInnerFusion]

-- | Run fusion inside the nested code carried by a single graph node.
--
-- * Loop ('DoNode') and match ('MatchNode') nodes: the statements of the
--   nodes recorded in the node's @to_fuse@ list are finalised and prepended
--   to the loop body (respectively to every case and the default branch).
--   That body is then fused, and the @to_fuse@ list of the rebuilt node is
--   emptied.  Case branches are renamed after fusion because each branch
--   receives its own copy of the delayed statements.
-- * VJP, JVP and WithAcc statements: their lambda is fused.
-- * SOAC nodes: the SOAC's lambda is fused, with scan fusion switched off
--   for 'H.Stream' and switched on for every other SOAC.
-- * Every other node is returned unchanged.
runInnerFusionOnContext :: DepContext -> FusionM DepContext
runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) =
  case nodeT of
    DoNode (Let pat aux (Loop params form body)) to_fuse ->
      let loopScope = scopeOfFParams (map fst params) <> scopeOf form
       in doFuseScans . localScope loopScope $ do
            body' <- doFusionWithDelayed body to_fuse
            pure . withNodeT $ DoNode (Let pat aux (Loop params form body')) []
    MatchNode (Let pat aux (Match cond cases defbody dec)) to_fuse ->
      doFuseScans $ do
        let fuseBranch b = doFusionWithDelayed b to_fuse
        cases' <- traverse (traverse (renameBody <=< fuseBranch)) cases
        defbody' <- fuseBranch defbody
        pure . withNodeT $
          MatchNode (Let pat aux (Match cond cases' defbody' dec)) []
    StmNode (Let pat aux (Op (Futhark.VJP lam args vec))) ->
      doFuseScans $ do
        lam' <- doFusionLambda lam
        pure . withNodeT . StmNode . Let pat aux . Op $ Futhark.VJP lam' args vec
    StmNode (Let pat aux (Op (Futhark.JVP lam args vec))) ->
      doFuseScans $ do
        lam' <- doFusionLambda lam
        pure . withNodeT . StmNode . Let pat aux . Op $ Futhark.JVP lam' args vec
    StmNode (Let pat aux (WithAcc inputs lam)) ->
      doFuseScans $ do
        lam' <- doFusionLambda lam
        pure . withNodeT . StmNode . Let pat aux $ WithAcc inputs lam'
    SoacNode ots pat soac aux -> do
      let lam = H.lambda soac
          -- Scan fusion is disabled inside the lambda of a Stream.
          withScanMode = case soac of
            H.Stream {} -> dontFuseScans
            _ -> doFuseScans
      lam' <- localScope (scopeOf lam) . withScanMode $ doFusionLambda lam
      pure . withNodeT $ SoacNode ots pat (H.setLambda lam' soac) aux
    _ -> pure c
  where
    -- Rebuild the context around a new node label.  The node index and both
    -- adjacency lists are kept unchanged.
    withNodeT :: NodeT -> DepContext
    withNodeT nodeT' = (incoming, node, nodeT', outgoing)

    -- Finalise the delayed nodes into statements and place them before the
    -- body's own statements.  The combined body is fused, and the original
    -- result is kept.
    doFusionWithDelayed :: Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
    doFusionWithDelayed (Body () stms res) extraNodes =
      localScope (scopeOf stms) $ do
        delayed <- mconcat <$> mapM (finalizeNode . fst) extraNodes
        fused <- fuseGraph $ mkBody (delayed <> stms) res
        pure (Body () fused res)

    -- Replace the statements of a body with their fused form.
    doFusionBody :: Body SOACS -> FusionM (Body SOACS)
    doFusionBody body = (\fused -> body {bodyStms = fused}) <$> fuseGraph body

    -- Fuse inside a lambda body.  The lambda is simplified beforehand to
    -- clean up after earlier fusion.  It is simplified again afterwards only
    -- if the fusion counter moved, i.e. something was fused inside.
    doFusionLambda :: Lambda SOACS -> FusionM (Lambda SOACS)
    doFusionLambda lam = do
      simplified <- simplifyLambda lam
      countBefore <- gets fusionCount
      body' <- localScope (scopeOf simplified) . doFusionBody $ lambdaBody simplified
      countAfter <- gets fusionCount
      let lam' = simplified {lambdaBody = body'}
      if countBefore == countAfter
        then pure lam'
        else simplifyLambda lam'

-- | Main fusion function.
--
-- Build the dependency graph of a body, run the complete fusion pipeline
-- ('doAllFusion') over it, and linearise the fused graph back into
-- statements.  The body's own bindings are in scope throughout.
fuseGraph :: Body SOACS -> FusionM (Stms SOACS)
fuseGraph body =
  localScope (scopeOf (bodyStms body)) $
    linearizeGraph =<< doAllFusion =<< mkDepGraph body

-- | Fuse the top-level constant statements.
--
-- The statements are wrapped in a body whose result is @outputs@, so
-- those names act as the body's results during fusion.  Fusion runs with a
-- fresh fusion environment, with the constants themselves in scope.
fuseConsts :: [VName] -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts outputs stms =
  runFusionM (scopeOf stms) freshFusionEnv . fuseGraph $
    mkBody stms (varsRes outputs)

-- | Fuse the body of a single function.
--
-- Fusion runs with a fresh fusion environment.  The scope contains both the
-- function definition's scope and the program constants @consts@.  Only
-- the statements of the function body are replaced; everything else in the
-- definition is kept.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  let body = funDefBody fun
      scope = scopeOf fun <> scopeOf consts
  fused <- runFusionM scope freshFusionEnv (fuseGraph body)
  pure fun {funDefBody = body {bodyStms = fused}}

-- | The pass definition.
--
-- Constants are fused with every constant that is free in some function
-- treated as an output.  Each function is then fused by 'fuseFun'.
{-# NOINLINE fuseSOACs #-}
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = fuseProg
    }
  where
    fuseProg prog =
      let usedConsts = namesToList . freeIn $ progFuns prog
       in intraproceduralTransformationWithConsts (fuseConsts usedConsts) fuseFun prog