{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform horizontal and vertical fusion of SOACs.  See the paper
-- /A T2 Graph-Reduction Approach To Fusion/ for the basic idea (some
-- extensions discussed in /Design and GPGPU Performance of Futhark’s
-- Redomap Construct/).
module Futhark.Optimise.Fusion (fuseSOACs) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import qualified Futhark.IR.Aliases as Aliases
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.IR.SOACS.Simplify
import Futhark.Optimise.Fusion.LoopKernel
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (maxinum)

-- | What the fusion pass records about a variable that is in scope.
data VarEntry
  = IsArray VName (NameInfo SOACS) Names SOAC.Input
  -- ^ An array-typed variable: an underlying array name (the variable
  -- itself, or the name inherited from the source entry when bound by
  -- 'bindingTransform'), its type information, the set of names it
  -- aliases, and its representation as a SOAC input (which may carry
  -- array transforms).
  | IsNotArray (NameInfo SOACS)
  -- ^ Any other variable; only its type information is kept.

-- | The type information recorded for a variable, whatever kind of
-- entry it is.
varEntryType :: VarEntry -> NameInfo SOACS
varEntryType (IsArray _ dec _ _) = dec
varEntryType (IsNotArray dec) = dec

-- | The names a variable is known to alias.  Only array entries track
-- aliases; non-array entries alias nothing.
varEntryAliases :: VarEntry -> Names
varEntryAliases (IsArray _ _ aliases _) = aliases
varEntryAliases IsNotArray {} = mempty

-- | The read-only environment of the fusion monad.
data FusionGEnv = FusionGEnv
  { -- | Mapping from variable name to its entire family.
    soacs :: M.Map VName [VName],
    -- | Everything known about each variable in scope (see 'VarEntry').
    varsInScope :: M.Map VName VarEntry,
    -- | The fusion result, installed by 'bindRes'.
    fusedRes :: FusedRes
  }

-- | Look up the SOAC input representation of a variable.  Returns
-- 'Nothing' when the variable is not in scope or is not an array.
lookupArr :: VName -> FusionGEnv -> Maybe SOAC.Input
lookupArr v env = asArray =<< M.lookup v (varsInScope env)
  where
    asArray :: VarEntry -> Maybe SOAC.Input
    asArray (IsArray _ _ _ input) = Just input
    asArray IsNotArray {} = Nothing

-- | An error raised by the fusion pass, carrying a human-readable message.
newtype Error = Error String

-- | Renders the message under a fixed "Fusion error:" heading.
instance Show Error where
  show (Error msg) = "Fusion error:\n" ++ msg

-- | The fusion monad: errors via 'ExceptT', a fresh-name source as
-- state, and a read-only 'FusionGEnv'.  All instances are obtained by
-- newtype deriving from the underlying transformer stack.
newtype FusionGM a = FusionGM (ExceptT Error (StateT VNameSource (Reader FusionGEnv)) a)
  deriving
    ( Monad,
      Applicative,
      Functor,
      MonadError Error,
      MonadState VNameSource,
      MonadReader FusionGEnv
    )

-- | The name source is simply the monad's 'StateT' state.
instance MonadFreshNames FusionGM where
  getNameSource = get
  putNameSource = put

-- | The scope is the type information of every variable the fusion
-- environment currently knows about.
instance HasScope SOACS FusionGM where
  askScope = asks $ M.map varEntryType . varsInScope

------------------------------------------------------------------------
--- Monadic Helpers: bind/new/runFusionGatherM, etc
------------------------------------------------------------------------

-- | Binds a variable in the environment.  An array-typed variable gets
-- an 'IsArray' entry whose aliases are the given names plus the aliases
-- already recorded for each of those names; any other variable gets an
-- 'IsNotArray' entry.
bindVar :: FusionGEnv -> (Ident, Names) -> FusionGEnv
bindVar env (Ident name t, aliases) =
  env {varsInScope = M.insert name entry $ varsInScope env}
  where
    entry = case t of
      Array {} ->
        IsArray name (LetName t) aliases' $
          SOAC.identInput $ Ident name t
      _ -> IsNotArray $ LetName t
    -- Aliases already known for one name (none if it is not in scope).
    expand v = maybe mempty varEntryAliases $ M.lookup v $ varsInScope env
    aliases' = aliases <> foldMap expand (namesToList aliases)

-- | Binds a list of variables left to right with 'bindVar', so later
-- entries may refer to (and shadow) earlier ones.  A strict left fold
-- avoids building a chain of thunks over long binding lists.
bindVars :: FusionGEnv -> [(Ident, Names)] -> FusionGEnv
bindVars = L.foldl' bindVar

-- | Runs an action with the given variables (and their aliases) bound
-- in the environment.
binding :: [(Ident, Names)] -> FusionGM a -> FusionGM a
binding vs = local (`bindVars` vs)

-- | Binds the pattern of a statement, pairing each pattern identifier
-- with the aliases that alias analysis (run with an empty alias table)
-- computes for the corresponding result of the bound expression.
gatherStmPat :: Pat -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPat pat e = binding $ zip idents aliases
  where
    idents = patIdents pat
    aliases = expAliases $ Alias.analyseExp mempty e

-- | Binds every identifier of a pattern, with no aliases.
bindingPat :: Pat -> FusionGM a -> FusionGM a
bindingPat pat = binding [(ident, mempty) | ident <- patIdents pat]

-- | Binds a list of parameters, with no aliases.
bindingParams :: Typed t => [Param t] -> FusionGM a -> FusionGM a
bindingParams params = binding [(paramIdent p, mempty) | p <- params]

-- | Binds an array name to the set of SOAC-produced variables it was
-- produced together with (its family), and records it in scope as an
-- array entry with no aliases.  The entry is always 'IsArray',
-- regardless of the identifier's type.
bindingFamilyVar :: [VName] -> FusionGEnv -> Ident -> FusionGEnv
bindingFamilyVar faml env (Ident nm t) =
  env
    { soacs = M.insert nm faml $ soacs env,
      varsInScope = M.insert nm entry $ varsInScope env
    }
  where
    entry = IsArray nm (LetName t) mempty $ SOAC.identInput $ Ident nm t
    }

-- | The variable itself together with every alias recorded for it in
-- scope (none if it is unknown or not an array).
varAliases :: VName -> FusionGM Names
varAliases v = asks $ \env ->
  oneName v <> maybe mempty varEntryAliases (M.lookup v $ varsInScope env)

-- | The union of 'varAliases' over a set of names.
varsAliases :: Names -> FusionGM Names
varsAliases = fmap mconcat . traverse varAliases . namesToList

-- | Records the effect of an in-place update on the fusion result.
-- Both the updated arrays @ip_vs@ and the additional names
-- @other_infuse_vs@ are marked infusible via 'addVarToInfusible'.
-- Every kernel's @inplace@ set is then extended with the aliases
-- (including themselves) of @ip_vs@.
updateKerInPlaces :: FusedRes -> ([VName], [VName]) -> FusionGM FusedRes
updateKerInPlaces res (ip_vs, other_infuse_vs) = do
  res' <- foldM addVarToInfusible res (ip_vs ++ other_infuse_vs)
  aliases <- mconcat <$> traverse varAliases ip_vs
  let inspectKer k = k {inplace = aliases <> inplace k}
  pure res' {kernels = M.map inspectKer $ kernels res'}

-- | If the expression performs an in-place update, records it with
-- 'updateKerInPlaces'; any other expression leaves the result unchanged.
--
-- * 'Update' / 'FlatUpdate': the source array is updated in place, and
--   the free variables of the slice are additionally marked infusible.
-- * 'Futhark.Scatter': every array written to is updated in place.
checkForUpdates :: FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates res (BasicOp (Update _ src slice _)) =
  updateKerInPlaces res ([src], namesToList $ freeIn slice)
checkForUpdates res (BasicOp (FlatUpdate src slice _)) =
  updateKerInPlaces res ([src], namesToList $ freeIn slice)
checkForUpdates res (Op (Futhark.Scatter _ _ _ written_info)) =
  updateKerInPlaces res ([arr | (_, _, arr) <- written_info], [])
checkForUpdates res _ = pure res

-- | Updates the environment for the duration of the given action:
--   (i) the @soacs@ map, by binding each pattern element name to the
--   names of all pattern elements (its family), and (ii) the variables
--   in scope, by inserting each pattern element as an array entry (see
--   'bindingFamilyVar').  This does not modify the @inplace@ field of
--   any kernel; in-place updates are handled separately (e.g. by
--   'checkForUpdates').
bindingFamily :: Pat -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily pat = local bind
  where
    family = patNames pat
    bind env = L.foldl' (bindingFamilyVar family) env $ patIdents pat

-- | Binds the pattern element @pe@ as the result of applying the array
-- transform @trns@ to @srcname@, for the duration of the given action.
--
-- If @srcname@ is an array entry in scope, the new entry keeps the
-- source's underlying array name and SOAC input (with @trns@ added via
-- 'SOAC.addTransform'), and aliases @srcname@ plus everything
-- @srcname@ aliases.  Otherwise @pe@ is bound with 'bindVar'.
bindingTransform :: PatElem -> VName -> SOAC.ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform pe srcname trns = local $ \env ->
  case M.lookup srcname $ varsInScope env of
    Just (IsArray src' _ aliases input) ->
      let entry =
            IsArray src' (LetName dec) (oneName srcname <> aliases) $
              trns `SOAC.addTransform` input
       in env {varsInScope = M.insert vname entry $ varsInScope env}
    _ ->
      -- NOTE(review): this records @vname@ as aliasing itself rather
      -- than @srcname@, unlike the branch above -- confirm intended.
      bindVar env (patElemIdent pe, oneName vname)
  where
    vname = patElemName pe
    dec = patElemDec pe

-- | Runs an action with the given fusion result installed as the
-- environment's 'fusedRes'.
bindRes :: FusedRes -> FusionGM a -> FusionGM a
bindRes rrr = local $ \env -> env {fusedRes = rrr}

-- | Runs a 'FusionGM' computation in the given environment.
--
-- The fresh-name source of the enclosing monad is threaded through as
-- the 'StateT' state, and is updated even when the computation fails.
-- The reader environment is a 'FusionGEnv' (variables in scope, SOAC
-- families, and the fusion result).  Errors raised via 'ExceptT' are
-- returned as 'Left'.
runFusionGatherM ::
  MonadFreshNames m =>
  FusionGM a ->
  FusionGEnv ->
  m (Either Error a)
runFusionGatherM (FusionGM a) env =
  modifyNameSource $ \src -> runReader (runStateT (runExceptT a) src) env

------------------------------------------------------------------------
--- Fusion Entry Points: gather the to-be-fused kernels at program  ---
---    level, then fuse them in a second pass.                      ---
------------------------------------------------------------------------

-- | The pass definition.  Fuses the top-level constants and every
-- function, then renames and simplifies the resulting program.  The
-- constants used by any function are kept as results when fusing the
-- constants, so they are not fused away.
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = \prog ->
        simplifySOACS
          =<< renameProg
          =<< intraproceduralTransformationWithConsts
            (fuseConsts (freeIn (progFuns prog)))
            fuseFun
            prog
    }

-- | Fuses the top-level constant statements, treating the constants in
-- @used_consts@ as results that must remain available.
fuseConsts :: Names -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts used_consts consts =
  fuseStms mempty consts . varsRes $ namesToList used_consts

-- | Fuses the body of a function.  The constants and the function's
-- parameters are in scope; the body's result is kept unchanged.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  let body = funDefBody fun
  stms <-
    fuseStms
      (scopeOf consts <> scopeOfFParams (funDefParams fun))
      (bodyStms body)
      (bodyResult body)
  pure fun {funDefBody = body {bodyStms = stms}}

-- | Fuses a sequence of statements in two passes, with @scope@ bound
-- (without aliases) in both passes:
--
-- 1. gather fusion opportunities with 'fusionGatherStms', keeping
--    @res@ as the statements' result, and clean up the outcome with
--    'cleanFusionResult';
-- 2. if gathering succeeded ('rsucc'), rewrite the statements with
--    'fuseInStms' using the gathered result; otherwise return @stms@
--    unchanged.
--
-- Errors from either pass are propagated through 'liftEitherM'.
fuseStms :: Scope SOACS -> Stms SOACS -> Result -> PassM (Stms SOACS)
fuseStms scope stms res = do
  k <-
    cleanFusionResult
      <$> liftEitherM
        ( runFusionGatherM
            (binding scope' $ fusionGatherStms mempty (stmsToList stms) res)
            env
        )
  if rsucc k
    then
      liftEitherM $
        runFusionGatherM (binding scope' $ bindRes k $ fuseInStms stms) env
    else pure stms
  where
    env =
      FusionGEnv
        { soacs = M.empty,
          varsInScope = mempty,
          fusedRes = mempty
        }
    scope' = [(Ident v $ typeOf info, mempty) | (v, info) <- M.toList scope]

---------------------------------------------------
---------------------------------------------------
---- RESULT's Data Structure
---------------------------------------------------
---------------------------------------------------

-- | A type used for (hopefully) uniquely referring to a producer SOAC
-- kernel.  The wrapped value is just a name: in this module, new kernel
-- names are created by 'addNewKerWithInfusible' with @newVName "ker"@,
-- i.e. they are fresh names rather than the name of one of the arrays the
-- SOAC returns.
newtype KernName = KernName {unKernName :: VName}
  deriving (Eq, Ord, Show)

-- | The fusion result that is threaded through and updated by 'greedyFuse'.
data FusedRes = FusedRes
  { -- | Whether we have fused something anywhere.
    rsucc :: Bool,
    -- | Associates an array to the name of the
    -- SOAC kernel that has produced it.
    outArr :: M.Map VName KernName,
    -- | Associates an array to the names of the
    -- SOAC kernels that use it. These sets include
    -- only the SOAC input arrays used as full variables, i.e., no @a[i]@.
    inpArr :: M.Map VName (S.Set KernName),
    -- | The (names of) arrays that are not fusible, i.e., arrays that
    --
    --   1. are used other than as input to SOAC kernels, or
    --
    --   2. are used as input to at least two different kernels that
    --      are not located on disjoint control-flow branches, or
    --
    --   3. are used in the lambda expression of SOACs.
    infusible :: Names,
    -- | Every kernel recorded so far, indexed by its name.
    kernels :: M.Map KernName FusedKer
  }

-- | Combine two fusion results field by field.  The success flag is a
-- disjunction; 'outArr' and 'kernels' use a left-biased union; the user
-- sets in 'inpArr' are merged per array; infusible names are joined.
instance Semigroup FusedRes where
  lhs <> rhs =
    FusedRes
      { rsucc = rsucc lhs || rsucc rhs,
        outArr = M.union (outArr lhs) (outArr rhs),
        inpArr = M.unionWith S.union (inpArr lhs) (inpArr rhs),
        infusible = infusible lhs <> infusible rhs,
        kernels = M.union (kernels lhs) (kernels rhs)
      }

-- | The empty fusion result: nothing fused, and no arrays, infusible
-- names, or kernels recorded.
instance Monoid FusedRes where
  mempty = FusedRes False M.empty M.empty mempty M.empty

-- | @isInpArrInResModKers ress kers nm@ holds when 'inpArr' records @nm@
-- as an input of at least one kernel that is /not/ a member of @kers@.
-- An array absent from 'inpArr' yields 'False'.
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
  maybe False hasOtherUser $ M.lookup nm $ inpArr ress
  where
    hasOtherUser users = not . S.null $ S.difference users kers

-- | All kernels that, according to 'inpArr', take at least one of the
-- given arrays as input.  Arrays unknown to 'inpArr' contribute nothing.
getKersWithInpArrs :: FusedRes -> [VName] -> S.Set KernName
getKersWithInpArrs ress arrs =
  S.unions [users | arr <- arrs, Just users <- [M.lookup arr (inpArr ress)]]

-- | Extend the list of names to include all the names produced via SOACs,
-- by querying the 'soacs' map of the environment.  A name that has an
-- entry there is replaced by that entry (its whole family); any other
-- name is kept as is.  The relative order of the input names is preserved.
expandSoacInpArr :: [VName] -> FusionGM [VName]
expandSoacInpArr = fmap concat . mapM expand
  where
    -- Concatenating per-name results avoids the quadratic @acc ++ [..]@
    -- accumulation of a left fold.
    expand :: VName -> FusionGM [VName]
    expand nm = asks $ fromMaybe [nm] . M.lookup nm . soacs

----------------------------------------------------------------------
----------------------------------------------------------------------

-- | Split the inputs of @soac@ into two name lists with 'getIdentArr' and
-- expand both with 'expandSoacInpArr'.  The first component is what
-- callers treat as the SOAC's input arrays; the second is the remaining
-- names (which 'greedyFuse' adds to the infusible set).
-- NOTE(review): the exact split criterion lives in 'getIdentArr' — confirm.
soacInputs :: SOAC -> FusionGM ([VName], [VName])
soacInputs soac = do
  let (inp_nms0, other_nms0) = getIdentArr $ SOAC.inputs soac
  inp_nms <- expandSoacInpArr inp_nms0
  other_nms <- expandSoacInpArr other_nms0
  return (inp_nms, other_nms)

-- | Register @soac@ (bound to @idd@) as a new, unfused kernel under a
-- fresh name.  Its output arrays are recorded in 'outArr', its input
-- arrays in 'inpArr', the infusible set is replaced by @ufs@, and 'rsucc'
-- is left unchanged.
addNewKerWithInfusible ::
  FusedRes ->
  ([Ident], StmAux (), SOAC, Names) ->
  Names ->
  FusionGM FusedRes
addNewKerWithInfusible res (idd, aux, soac, consumed) ufs = do
  ker_name <- KernName <$> newVName "ker"
  ker_scope <- askScope
  let outputs = map identName idd
      inputs = map SOAC.inputArray $ SOAC.inputs soac
      produced_by = M.fromList $ zip outputs (repeat ker_name)
      used_by = M.fromList [(arr, S.singleton ker_name) | arr <- inputs]
      ker = newKernel aux soac consumed outputs ker_scope
  pure
    FusedRes
      { rsucc = rsucc res,
        outArr = produced_by `M.union` outArr res,
        inpArr = M.unionWith S.union used_by (inpArr res),
        infusible = ufs,
        kernels = M.insert ker_name ker (kernels res)
      }

-- | Look up an array name in the environment with 'lookupArr'.
lookupInput :: VName -> FusionGM (Maybe SOAC.Input)
lookupInput = asks . lookupArr

-- | If the array of this input is known in the environment (via
-- 'lookupInput'), replace it by the known input, combining the transforms
-- as @known_ts <> ts@.  Otherwise the input is returned unchanged.
inlineSOACInput :: SOAC.Input -> FusionGM SOAC.Input
inlineSOACInput inp@(SOAC.Input ts v _) = do
  known <- lookupInput v
  case known of
    Nothing -> pure inp
    Just (SOAC.Input known_ts known_v known_t) ->
      pure $ SOAC.Input (known_ts <> ts) known_v known_t

-- | Apply 'inlineSOACInput' to every input of the SOAC, left to right.
inlineSOACInputs :: SOAC -> FusionGM SOAC
inlineSOACInputs soac =
  (`SOAC.setInputs` soac) <$> traverse inlineSOACInput (SOAC.inputs soac)

-- | Attempts to fuse between SOACs. Input:
--   @rem_stms@ are the bindings remaining in the current body after @orig_soac@.
--   @lam_used_nms@ are further names used by the SOAC; they take part in
--   the in-place-update check below.
--   @res@ the fusion result (before processing the current SOAC).
--   @orig_soac@ and @out_idds@ the current SOAC and its binding pattern.
--   @consumed@ is the set of names consumed by the SOAC.
--   Output: a new fusion result (after processing the current SOAC binding).
greedyFuse ::
  [Stm] ->
  Names ->
  FusedRes ->
  (Pat, StmAux (), SOAC, Names) ->
  FusionGM FusedRes
greedyFuse rem_stms lam_used_nms res (out_idds, aux, orig_soac, consumed) = do
  soac <- inlineSOACInputs orig_soac
  (inp_nms, other_nms) <- soacInputs soac
  -- Assumption: the free vars in lambda are already in @infusible res@.
  let out_nms = patNames out_idds
      isInfusible = (`nameIn` infusible res)
      -- A genuine redomap/scanomap, i.e. not a plain reduce or scan.
      is_screma = case orig_soac of
        SOAC.Screma _ form _ ->
          (isJust (isRedomapSOAC form) || isJust (isScanomapSOAC form))
            && not (isJust (isReduceSOAC form) || isJust (isScanSOAC form))
        _ -> False
  --
  -- (i) Choice of fusion strategy (without duplicating computation):
  -- IF the current SOAC is such a redomap/scanomap, OR any of @out_idds@
  -- already belongs to the infusible set,
  -- THEN try horizontal fusion,
  -- ELSE try producer-consumer fusion.
  (ok_kers_compat, fused_kers, fused_nms, old_kers, oldker_nms) <-
    if is_screma || any isInfusible out_nms
      then horizontGreedyFuse rem_stms res (out_idds, aux, soac, consumed)
      else prodconsGreedyFuse res (out_idds, aux, soac, consumed)
  --
  -- (ii) check whether fusing @soac@ will violate any in-place update
  --      restriction, e.g., would move an input array past its in-place update.
  all_used_names <-
    fmap mconcat . mapM varAliases . namesToList $
      mconcat [lam_used_nms, namesFromList inp_nms, namesFromList other_nms]
  let has_inplace ker = inplace ker `namesIntersect` all_used_names
      ok_inplace = not $ any has_inplace old_kers
  --
  -- (iii)  there are some kernels that use some of `out_idds' as inputs
  -- (iv)   and producer-consumer or horizontal fusion succeeds with those.
  let fusible_ker = not (null old_kers) && ok_inplace && ok_kers_compat
  --
  -- Start constructing the fusion's result:
  --  (i) inparr ids other than vars will be added to infusible list,
  -- (ii) will also become part of the infusible set the inparr vars
  --         that also appear as inparr of another kernel,
  --         BUT which said kernel is not the one we are fusing with (now)!
  let mod_kerS = if fusible_ker then S.fromList oldker_nms else mempty
      used_inps = filter (isInpArrInResModKers res mod_kerS) inp_nms
      ufs =
        mconcat
          [ infusible res,
            namesFromList used_inps,
            namesFromList other_nms
              `namesSubtract` namesFromList (map SOAC.inputArray $ SOAC.inputs soac)
          ]

  if not fusible_ker
    then addNewKerWithInfusible res (patIdents out_idds, aux, soac, consumed) ufs
    else do
      let comb = M.unionWith S.union
          -- Remove kernel @knm@ from a user set; 'Nothing' (i.e. drop the
          -- map entry) when no user is left.
          forgetKer knm users =
            let users' = S.delete knm users
             in if S.null users' then Nothing else Just users'
          -- (i) first remove the inpArr bindings of the old kernels ...
          unbindOld inpa (kold, knm) =
            S.foldl' (\inpp nm -> M.update (forgetKer knm) nm inpp) inpa $
              arrInputs kold
          -- Strict folds: the accumulators are Maps, so a lazy 'foldl'
          -- would only build a chain of thunks.
          inpArr' = L.foldl' unbindOld (inpArr res) (zip old_kers oldker_nms)
          -- (ii) ... then add the inpArr bindings of the new kernels.
          fused_ker_nms = zip fused_nms fused_kers
          bindNew inpa (knm, knew) =
            M.fromList [(k, S.singleton knm) | k <- S.toList $ arrInputs knew]
              `comb` inpa
          inpArr'' = L.foldl' bindNew inpArr' fused_ker_nms
          -- Left-biased union: fused kernels replace entries of the same name.
          -- NOTE(review): old kernels fused under a *different* name are not
          -- deleted from the map here.
          kernels' = M.fromList fused_ker_nms `M.union` kernels res
      -- Nothing to do for `outArr' (since we have not added a new kernel).
      -- DO IMPROVEMENT: attempt to fuse the resulting kernel AGAIN until it
      -- fails, but make sure NOT to add a new kernel!
      return $ FusedRes True (outArr res) inpArr'' ufs kernels'

-- | Producer-consumer fusion: try to fuse @soac@ into every kernel that
-- takes one of its outputs as input (according to 'inpArr').
--
-- Returns @(ok, fused, fused_nms, old, old_nms)@, where @old@ are those
-- consumer kernels (looked up in 'kernels') and @old_nms@ their names.
-- If 'attemptFusion' succeeds for every consumer, @ok@ is 'True' and
-- @fused@ holds the results with @aux@ appended to their 'kerAux';
-- otherwise @ok@ is 'False' and @fused@ is empty.  @fused_nms@ is the
-- same list as @old_nms@.  Throws an 'Error' if a consumer name is missing
-- from 'kernels'.
prodconsGreedyFuse ::
  FusedRes ->
  (Pat, StmAux (), SOAC, Names) ->
  FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
prodconsGreedyFuse res (out_idds, aux, soac, consumed) = do
  let out_nms = patNames out_idds
      consumer_nms = S.toList $ getKersWithInpArrs res out_nms
  consumers <- mapM findKernel consumer_nms
  attempts <- mapM (attemptFusion mempty out_nms soac consumed) consumers
  (ok, fused) <- case sequence attempts of
    Nothing -> return (False, [])
    Just kers -> return (True, map addAux kers)
  return (ok, fused, consumer_nms, consumers, consumer_nms)
  where
    findKernel knm =
      case M.lookup knm (kernels res) of
        Just ker -> return ker
        Nothing ->
          throwError . Error $
            "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
              ++ "kernel name not found in kernels field!"
    addAux ker = ker {kerAux = kerAux ker <> aux}

-- | Attempt horizontal fusion of the given SOAC with kernels collected
-- in the fusion result.
--
-- Candidate kernels are those returned by 'getKersWithInpArrs' for the
-- SOAC's non-accumulator outputs and its input arrays, together with the
-- kernels whose SOAC width equals the width of this SOAC.  Candidates are
-- sorted by the index (in @rem_stms@) of the last statement binding one
-- of their outputs, and are then fused one at a time, starting from a
-- kernel built from the SOAC itself.  Once one attempt is rejected or
-- fails, no later candidate is fused.
--
-- Returns, in order:
--
--   * whether at least one candidate was fused;
--
--   * the resulting kernel, as a singleton list;
--
--   * the name of the last candidate fused into (empty if none);
--
--   * the candidate kernels that were fused;
--
--   * their names.
horizontGreedyFuse ::
  [Stm] ->
  FusedRes ->
  (Pat, StmAux (), SOAC, Names) ->
  FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse rem_stms res (out_idds, aux, soac, consumed) = do
  (inp_nms, _) <- soacInputs soac
  let out_nms = patNames out_idds
      infusible_nms = namesFromList $ filter (`nameIn` infusible res) out_nms
      out_arr_nms = case soac of
        -- the accumulator result cannot be fused!
        SOAC.Screma _ (ScremaForm scans reds _) _ ->
          drop (scanResults scans + redResults reds) out_nms
        SOAC.Stream _ _ _ nes _ -> drop (length nes) out_nms
        _ -> out_nms
      to_fuse_knms1 = S.toList $ getKersWithInpArrs res (out_arr_nms ++ inp_nms)
      to_fuse_knms2 = getKersWithSameInpSize (SOAC.width soac) res
      -- deduplicate the two candidate sets
      to_fuse_knms = S.toList $ S.fromList $ to_fuse_knms1 ++ to_fuse_knms2
      lookupKernel :: KernName -> FusionGM FusedKer
      lookupKernel k = case M.lookup k (kernels res) of
        Nothing ->
          throwError $
            Error
              ( "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
                  ++ "kernel name not found in kernels field!"
              )
        Just ker -> return ker

  -- For each kernel get the index in the bindings where the kernel is
  -- located and sort based on the index so that partial fusion may
  -- succeed.  We use the last position where one of the kernel
  -- outputs occur.
  let stm_nms = map (patNames . stmPat) rem_stms
  kernminds <- forM to_fuse_knms $ \ker_nm -> do
    ker <- lookupKernel ker_nm
    case mapMaybe (\out_nm -> L.findIndex (elem out_nm) stm_nms) (outNames ker) of
      [] -> return Nothing
      is -> return $ Just (ker, ker_nm, maxinum is)

  scope <- askScope
  let kernminds' = L.sortBy (\(_, _, i1) (_, _, i2) -> compare i1 i2) $ catMaybes kernminds
      soac_kernel = newKernel aux soac consumed out_nms scope

  -- Now try to fuse kernels one by one (in a fold); @ok_ind@ is the index of
  -- the kernel until which fusion succeeded, and @fused_ker@ is the resulting
  -- kernel.
  (_, ok_ind, _, fused_ker, _) <-
    foldM
      ( \(cur_ok, n, prev_ind, cur_ker, ufus_nms) (ker, _ker_nm, stm_ind) -> do
          -- check that we still try fusion and that the intermediate
          -- bindings do not use the results of cur_ker
          let curker_outnms = outNames cur_ker
              curker_outset = namesFromList curker_outnms
              new_ufus_nms = namesFromList $ outNames ker ++ namesToList ufus_nms
              -- disable horizontal fusion in the case when an output array of
              -- producer SOAC is a non-trivially transformed input of the consumer
              out_transf_ok =
                let ker_inp = SOAC.inputs $ fsoac ker
                    unfuse1 =
                      namesFromList (map SOAC.inputArray ker_inp)
                        `namesSubtract` namesFromList (mapMaybe SOAC.isVarInput ker_inp)
                    unfuse2 = namesIntersection curker_outset ufus_nms
                 in not $ unfuse1 `namesIntersect` unfuse2
              -- Disable horizontal fusion if consumer has any
              -- output transforms.
              cons_no_out_transf = SOAC.nullTransforms $ outputTransform ker
              -- check that consumer's lambda body does not use
              -- directly the produced arrays (e.g., see noFusion3.fut).
              consumer_ok =
                not $
                  curker_outset
                    `namesIntersect` freeIn (lambdaBody $ SOAC.lambda $ fsoac ker)
              -- The statements strictly between the previous position and
              -- this kernel's position.
              interm_stms = drop (prev_ind + 1) $ take stm_ind rem_stms
              -- An in-between statement is acceptable if either
              --  (i) it does not use the result of the current kernel, OR
              -- (ii) its pattern binds a result of the current kernel; in
              --      the latter case it corresponds to a kernel that has
              --      been fused in the consumer, hence it is ignored.
              -- The grouping is explicit: previously this was written as
              -- @ok && (i) || (ii)@, which parses as @(ok && (i)) || (ii)@
              -- and so let a later (ii)-statement cancel an earlier failure.
              stm_ok stm =
                not (curker_outset `namesIntersect` freeIn (stmExp stm))
                  || not (null $ curker_outnms `L.intersect` patNames (stmPat stm))
              interm_stms_ok =
                cur_ok
                  && consumer_ok
                  && out_transf_ok
                  && cons_no_out_transf
                  && all stm_ok interm_stms
          if not interm_stms_ok
            then return (False, n, stm_ind, cur_ker, mempty)
            else do
              new_ker <-
                attemptFusion
                  ufus_nms
                  (outNames cur_ker)
                  (fsoac cur_ker)
                  (fusedConsumed cur_ker)
                  ker
              case new_ker of
                Nothing -> return (False, n, stm_ind, cur_ker, mempty)
                Just krn ->
                  let krn' = krn {kerAux = aux <> kerAux krn}
                   in return (True, n + 1, stm_ind, krn', new_ufus_nms)
      )
      (True, 0, 0, soac_kernel, infusible_nms)
      kernminds'

  -- Find the kernels we have fused into and the name of the last such
  -- kernel (if any).
  let (to_fuse_kers', to_fuse_knms', _) = unzip3 $ take ok_ind kernminds'
      new_kernms = drop (ok_ind - 1) to_fuse_knms'

  return (ok_ind > 0, [fused_ker], new_kernms, to_fuse_kers', to_fuse_knms')
  where
    -- Names of all kernels whose SOAC has the given width.
    getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
    getKersWithSameInpSize sz ress =
      map fst $ filter (\(_, ker) -> sz == SOAC.width (fsoac ker)) $ M.toList $ kernels ress

------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------
--- Fusion Gather for EXPRESSIONS and BODIES,                        ---
--- i.e., where work is being done:                                  ---
---    i) bottom-up AbSyn traversal (backward analysis)              ---
---   ii) SOACs are fused greedily iff this does not duplicate work  ---
--- E.g., (y1, y2, y3) = mapT(f, x1, x2[i])                          ---
---       (z1, z2)     = mapT(g1, y1, y2)                            ---
---       (q1, q2)     = mapT(g2, y3, z1, a, y3)                     ---
---       res          = reduce(op, ne, q1, q2, z2, y1, y3)          ---
--- can be fused if y1,y2,y3, z1,z2, q1,q2 are not used elsewhere:   ---
---       res = redomap(op, \(x1,x2i,a)->                            ---
---                             let (y1,y2,y3) = f (x1, x2i)       in---
---                             let (z1,z2)    = g1(y1, y2)        in---
---                             let (q1,q2)    = g2(y3, z1, a, y3) in---
---                             (q1, q2, z2, y1, y3)                 ---
---                     x1, x2[i], a)                                ---
------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------

-- | Gather fusion information for a body: its statements (as a list)
-- and its result are handed to 'fusionGatherStms', starting from the
-- given fusion result.
fusionGatherBody :: FusedRes -> Body -> FusionGM FusedRes
fusionGatherBody fres (Body _ body_stms body_res) =
  fusionGatherStms fres stm_list body_res
  where
    stm_list = stmsToList body_stms

fusionGatherStms :: FusedRes -> [Stm] -> Result -> FusionGM FusedRes
-- Some forms of do-loops can profitably be considered streamSeqs.  We
-- are careful to ensure that the generated nested loop cannot itself
-- be considered a stream, to avoid infinite recursion.
fusionGatherStms :: FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms
  FusedRes
fres
  (Let (Pat [PatElem]
pes) StmAux (ExpDec SOACS)
stmtp (DoLoop [(FParam SOACS, SubExp)]
merge (ForLoop VName
i IntType
it SubExp
w [(LParam SOACS, VName)]
loop_vars) BodyT SOACS
body) : [Stm]
stms)
  Result
res
    | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(Param Type, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars = do
      let ([Param DeclType]
merge_params, [SubExp]
merge_init) = [(Param DeclType, SubExp)] -> ([Param DeclType], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge
          ([Param Type]
loop_params, [VName]
loop_arrs) = [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars
      Param Type
chunk_param <- String -> Type -> FusionGM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"chunk_size" (Type -> FusionGM (Param Type)) -> Type -> FusionGM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      Param Type
offset_param <- String -> Type -> FusionGM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"offset" (Type -> FusionGM (Param Type)) -> Type -> FusionGM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
it
      let offset :: VName
offset = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
offset_param
          chunk_size :: VName
chunk_size = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_param

      [Param Type]
acc_params <- [Param DeclType]
-> (Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param DeclType]
merge_params ((Param DeclType -> FusionGM (Param Type))
 -> FusionGM [Param Type])
-> (Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall a b. (a -> b) -> a -> b
$ \Param DeclType
p ->
        String -> Type -> FusionGM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p) String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_outer") (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
p)

      [Param Type]
chunked_params <- [(Param Type, VName)]
-> ((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars (((Param Type, VName) -> FusionGM (Param Type))
 -> FusionGM [Param Type])
-> ((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
        String -> Type -> FusionGM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam
          (VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_chunk")
          (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` VName -> SubExp
Futhark.Var VName
chunk_size)

      let lam_params :: [Param Type]
lam_params = Param Type
chunk_param Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
acc_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type
offset_param] [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
chunked_params

      BodyT SOACS
lam_body <- Builder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS))
-> Builder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$
        Scope SOACS
-> Builder SOACS (BodyT SOACS) -> Builder SOACS (BodyT SOACS)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param Type] -> Scope SOACS
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param Type]
lam_params) (Builder SOACS (BodyT SOACS) -> Builder SOACS (BodyT SOACS))
-> Builder SOACS (BodyT SOACS) -> Builder SOACS (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$ do
          let merge' :: [(Param DeclType, SubExp)]
merge' = [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
merge_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
acc_params
          VName
j <- String -> BuilderT SOACS (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"j"
          BodyT SOACS
loop_body <- Builder SOACS (BodyT SOACS) -> Builder SOACS (BodyT SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder SOACS (BodyT SOACS) -> Builder SOACS (BodyT SOACS))
-> Builder SOACS (BodyT SOACS) -> Builder SOACS (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$ do
            [(Param Type, Param Type)]
-> ((Param Type, Param Type)
    -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [Param Type] -> [(Param Type, Param Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
loop_params [Param Type]
chunked_params) (((Param Type, Param Type)
  -> BuilderT SOACS (State VNameSource) ())
 -> BuilderT SOACS (State VNameSource) ())
-> ((Param Type, Param Type)
    -> BuilderT SOACS (State VNameSource) ())
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, Param Type
a_p) ->
              [VName]
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Rep (BuilderT SOACS (State VNameSource)))
 -> BuilderT SOACS (State VNameSource) ())
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp) -> BasicOp -> Exp
forall a b. (a -> b) -> a -> b
$
                  VName -> Slice SubExp -> BasicOp
Index (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
a_p) (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                    Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
a_p) [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Futhark.Var VName
j]
            [VName]
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
i] (Exp (Rep (BuilderT SOACS (State VNameSource)))
 -> BuilderT SOACS (State VNameSource) ())
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp) -> BasicOp -> Exp
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) (VName -> SubExp
Futhark.Var VName
offset) (VName -> SubExp
Futhark.Var VName
j)
            BodyT SOACS -> Builder SOACS (BodyT SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return BodyT SOACS
body
          [BuilderT
   SOACS
   (State VNameSource)
   (Exp (Rep (BuilderT SOACS (State VNameSource))))]
-> BuilderT
     SOACS
     (State VNameSource)
     (Body (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
            [ Exp -> BuilderT SOACS (State VNameSource) Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp -> BuilderT SOACS (State VNameSource) Exp)
-> Exp -> BuilderT SOACS (State VNameSource) Exp
forall a b. (a -> b) -> a -> b
$
                [(FParam SOACS, SubExp)] -> LoopForm SOACS -> BodyT SOACS -> Exp
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge' (VName
-> IntType -> SubExp -> [(LParam SOACS, VName)] -> LoopForm SOACS
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
j IntType
it (VName -> SubExp
Futhark.Var VName
chunk_size) []) BodyT SOACS
loop_body,
              Exp -> BuilderT SOACS (State VNameSource) Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp -> BuilderT SOACS (State VNameSource) Exp)
-> Exp -> BuilderT SOACS (State VNameSource) Exp
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp) -> BasicOp -> Exp
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Futhark.Var VName
offset) (VName -> SubExp
Futhark.Var VName
chunk_size)
            ]
      let lam :: Lambda SOACS
lam =
            Lambda :: forall rep. [LParam rep] -> BodyT rep -> [Type] -> LambdaT rep
Lambda
              { lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
[LParam SOACS]
lam_params,
                lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
lam_body,
                lambdaReturnType :: [Type]
lambdaReturnType = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType ([Param Type] -> [Type]) -> [Param Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ [Param Type]
acc_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type
offset_param]
              }
          stream :: SOAC SOACS
stream = SubExp
-> [VName]
-> StreamForm SOACS
-> [SubExp]
-> Lambda SOACS
-> SOAC SOACS
forall rep.
SubExp
-> [VName] -> StreamForm rep -> [SubExp] -> Lambda rep -> SOAC rep
Futhark.Stream SubExp
w [VName]
loop_arrs StreamForm SOACS
forall rep. StreamForm rep
Sequential ([SubExp]
merge_init [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
it Integer
0]) Lambda SOACS
lam

      -- It is important that the (discarded) final-offset is not the
      -- first element in the pattern, as we use the first element to
      -- identify the SOAC in the second phase of fusion.
      VName
discard <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"discard"
      let discard_pe :: PatElemT Type
discard_pe = VName -> Type -> PatElemT Type
forall dec. VName -> dec -> PatElemT dec
PatElem VName
discard (Type -> PatElemT Type) -> Type -> PatElemT Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

      FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms
        FusedRes
fres
        (Pat -> StmAux (ExpDec SOACS) -> Exp -> Stm
forall rep. Pat rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElemT Type] -> PatT Type
forall dec. [PatElemT dec] -> PatT dec
Pat ([PatElemT Type]
[PatElem]
pes [PatElemT Type] -> [PatElemT Type] -> [PatElemT Type]
forall a. Semigroup a => a -> a -> a
<> [PatElemT Type
discard_pe])) StmAux (ExpDec SOACS)
stmtp (Op SOACS -> Exp
forall rep. Op rep -> ExpT rep
Op Op SOACS
SOAC SOACS
stream) Stm -> [Stm] -> [Stm]
forall a. a -> [a] -> [a]
: [Stm]
stms)
        Result
res
fusionGatherStms FusedRes
fres (stm :: Stm
stm@(Let Pat
pat StmAux (ExpDec SOACS)
_ Exp
e) : [Stm]
stms) Result
res = do
  Either NotSOAC SOAC
maybesoac <- Exp -> FusionGM (Either NotSOAC SOAC)
forall rep (m :: * -> *).
(Op rep ~ SOAC rep, HasScope rep m) =>
Exp rep -> m (Either NotSOAC (SOAC rep))
SOAC.fromExp Exp
e
  case Either NotSOAC SOAC
maybesoac of
    Right soac :: SOAC
soac@(SOAC.Scatter SubExp
_len Lambda SOACS
lam [Input]
_ivs [(Shape, Int, VName)]
_as) -> do
      -- We put the variables produced by Scatter into the infusible
      -- set to force horizontal fusion.  It is not possible to
      -- producer/consumer-fuse Scatter anyway.
      FusedRes
fres' <- FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible FusedRes
fres (Names -> FusionGM FusedRes) -> Names -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
Pat
pat
      FusedRes
fres'' <- FusedRes -> SOAC -> Lambda SOACS -> FusionGM FusedRes
mapLike FusedRes
fres' SOAC
soac Lambda SOACS
lam
      FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates FusedRes
fres'' Exp
e
    Right soac :: SOAC
soac@(SOAC.Hist SubExp
_ [HistOp SOACS]
_ Lambda SOACS
lam [Input]
_) -> do
      -- We put the variables produced by Hist into the infusible
      -- set to force horizontal fusion.  It is not possible to
      -- producer/consumer-fuse Hist anyway.
      FusedRes
fres' <- FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible FusedRes
fres (Names -> FusionGM FusedRes) -> Names -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
Pat
pat
      FusedRes -> SOAC -> Lambda SOACS -> FusionGM FusedRes
mapLike FusedRes
fres' SOAC
soac Lambda SOACS
lam
    Right soac :: SOAC
soac@(SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans [Reduce SOACS]
reds Lambda SOACS
map_lam) [Input]
_) ->
      SOAC -> [Lambda SOACS] -> [SubExp] -> FusionGM FusedRes
reduceLike SOAC
soac ((Scan SOACS -> Lambda SOACS) -> [Scan SOACS] -> [Lambda SOACS]
forall a b. (a -> b) -> [a] -> [b]
map Scan SOACS -> Lambda SOACS
forall rep. Scan rep -> Lambda rep
scanLambda [Scan SOACS]
scans [Lambda SOACS] -> [Lambda SOACS] -> [Lambda SOACS]
forall a. Semigroup a => a -> a -> a
<> (Reduce SOACS -> Lambda SOACS) -> [Reduce SOACS] -> [Lambda SOACS]
forall a b. (a -> b) -> [a] -> [b]
map Reduce SOACS -> Lambda SOACS
forall rep. Reduce rep -> Lambda rep
redLambda [Reduce SOACS]
reds [Lambda SOACS] -> [Lambda SOACS] -> [Lambda SOACS]
forall a. Semigroup a => a -> a -> a
<> [Lambda SOACS
map_lam]) ([SubExp] -> FusionGM FusedRes) -> [SubExp] -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$
        (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds
    Right soac :: SOAC
soac@(SOAC.Stream SubExp
_ StreamForm SOACS
form Lambda SOACS
lam [SubExp]
nes [Input]
_) -> do
      -- a redomap does not neccessarily start a new kernel, e.g.,
      -- @let a= reduce(+,0,A) in ... stms ... in let B = map(f,A)@
      -- can be fused into a redomap that replaces the @map@, if @a@
      -- and @B@ are defined in the same scope and @stms@ does not uses @a@.
      -- a redomap always starts a new kernel
      let lambdas :: [Lambda SOACS]
lambdas = case StreamForm SOACS
form of
            Parallel StreamOrd
_ Commutativity
_ Lambda SOACS
lout -> [Lambda SOACS
lout, Lambda SOACS
lam]
            StreamForm SOACS
Sequential -> [Lambda SOACS
lam]
      SOAC -> [Lambda SOACS] -> [SubExp] -> FusionGM FusedRes
reduceLike SOAC
soac [Lambda SOACS]
lambdas [SubExp]
nes
    Either NotSOAC SOAC
_
      | Pat [PatElem
pe] <- Pat
pat,
        Just (VName
src, ArrayTransform
trns) <- Certs -> Exp -> Maybe (VName, ArrayTransform)
forall rep. Certs -> Exp rep -> Maybe (VName, ArrayTransform)
SOAC.transformFromExp (Stm -> Certs
forall rep. Stm rep -> Certs
stmCerts Stm
stm) Exp
e ->
        PatElem
-> VName
-> ArrayTransform
-> FusionGM FusedRes
-> FusionGM FusedRes
forall a.
PatElem -> VName -> ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform PatElem
pe VName
src ArrayTransform
trns (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms FusedRes
fres [Stm]
stms Result
res
      | Bool
otherwise -> do
        let pat_vars :: [Exp]
pat_vars = (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> Exp
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp) -> (VName -> BasicOp) -> VName -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [Exp]) -> [VName] -> [Exp]
forall a b. (a -> b) -> a -> b
$ PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
Pat
pat
        FusedRes
bres <- Pat -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPat Pat
pat Exp
e (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms FusedRes
fres [Stm]
stms Result
res
        FusedRes
bres' <- FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates FusedRes
bres Exp
e
        (FusedRes -> Exp -> FusionGM FusedRes)
-> FusedRes -> [Exp] -> FusionGM FusedRes
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp FusedRes
bres' (Exp
e Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
pat_vars)
  where
    aux :: StmAux (ExpDec SOACS)
aux = Stm -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm
stm
    rem_stms :: [Stm]
rem_stms = Stm
stm Stm -> [Stm] -> [Stm]
forall a. a -> [a] -> [a]
: [Stm]
stms
    consumed :: Names
consumed = Exp (Aliases SOACS) -> Names
forall rep. Aliased rep => Exp rep -> Names
consumedInExp (Exp (Aliases SOACS) -> Names) -> Exp (Aliases SOACS) -> Names
forall a b. (a -> b) -> a -> b
$ AliasTable -> Exp -> Exp (Aliases SOACS)
forall rep.
(ASTRep rep, CanBeAliased (Op rep)) =>
AliasTable -> Exp rep -> Exp (Aliases rep)
Alias.analyseExp AliasTable
forall a. Monoid a => a
mempty Exp
e

    reduceLike :: SOAC -> [Lambda SOACS] -> [SubExp] -> FusionGM FusedRes
reduceLike SOAC
soac [Lambda SOACS]
lambdas [SubExp]
nes = do
      (Names
used_lam, FusedRes
lres) <- ((Names, FusedRes) -> Lambda SOACS -> FusionGM (Names, FusedRes))
-> (Names, FusedRes)
-> [Lambda SOACS]
-> FusionGM (Names, FusedRes)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Names, FusedRes) -> Lambda SOACS -> FusionGM (Names, FusedRes)
fusionGatherLam (Names
forall a. Monoid a => a
mempty, FusedRes
fres) [Lambda SOACS]
lambdas
      FusedRes
bres <- Pat -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily Pat
pat (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms FusedRes
lres [Stm]
stms Result
res
      FusedRes
bres' <- (FusedRes -> SubExp -> FusionGM FusedRes)
-> FusedRes -> [SubExp] -> FusionGM FusedRes
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp FusedRes
bres [SubExp]
nes
      Names
consumed' <- Names -> FusionGM Names
varsAliases Names
consumed
      [Stm]
-> Names
-> FusedRes
-> (Pat, StmAux (), SOAC, Names)
-> FusionGM FusedRes
greedyFuse [Stm]
rem_stms Names
used_lam FusedRes
bres' (Pat
pat, StmAux ()
StmAux (ExpDec SOACS)
aux, SOAC
soac, Names
consumed')

    mapLike :: FusedRes -> SOAC -> Lambda SOACS -> FusionGM FusedRes
mapLike FusedRes
fres' SOAC
soac Lambda SOACS
lambda = do
      FusedRes
bres <- Pat -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily Pat
pat (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms FusedRes
fres' [Stm]
stms Result
res
      (Names
used_lam, FusedRes
blres) <- (Names, FusedRes) -> Lambda SOACS -> FusionGM (Names, FusedRes)
fusionGatherLam (Names
forall a. Monoid a => a
mempty, FusedRes
bres) Lambda SOACS
lambda
      Names
consumed' <- Names -> FusionGM Names
varsAliases Names
consumed
      [Stm]
-> Names
-> FusedRes
-> (Pat, StmAux (), SOAC, Names)
-> FusionGM FusedRes
greedyFuse [Stm]
rem_stms Names
used_lam FusedRes
blres (Pat
pat, StmAux ()
StmAux (ExpDec SOACS)
aux, SOAC
soac, Names
consumed')
fusionGatherStms FusedRes
fres [] Result
res =
  (FusedRes -> Exp -> FusionGM FusedRes)
-> FusedRes -> [Exp] -> FusionGM FusedRes
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp FusedRes
fres ([Exp] -> FusionGM FusedRes) -> [Exp] -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> Exp) -> Result -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> Exp
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp) -> (SubExpRes -> BasicOp) -> SubExpRes -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp (SubExp -> BasicOp)
-> (SubExpRes -> SubExp) -> SubExpRes -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
res

-- | Gather fusion information from an expression that is not itself a
-- SOAC statement.  Loops, branches and 'WithAcc' are traversed
-- structurally; any other expression marks all of its free variables
-- as infusible.
fusionGatherExp :: FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp fres (DoLoop merge form loop_body) = do
  -- Names used by the loop form and by the initial merge values must
  -- not be fused away.
  fres' <- addNamesToInfusible fres $ freeIn form <> freeIn merge
  let form_idents =
        case form of
          ForLoop i it _ loopvars ->
            Ident i (Prim (IntType it)) : map (paramIdent . fst) loopvars
          WhileLoop {} -> []

  new_res <-
    binding (zip (form_idents ++ map (paramIdent . fst) merge) $ repeat mempty) $
      fusionGatherBody mempty loop_body
  -- Make the input arrays of kernels inside the loop infusible, so
  -- that they cannot be fused from outside the loop.
  let inp_arrs = namesFromList $ M.keys $ inpArr new_res
      new_res' = new_res {infusible = infusible new_res <> inp_arrs}
  -- Merge the loop's result with the result for the enclosing code.
  return $ new_res' <> fres'
fusionGatherExp fres (If cond e_then e_else _) = do
  then_res <- fusionGatherBody mempty e_then
  else_res <- fusionGatherBody mempty e_else
  let both_res = then_res <> else_res
  fres' <- fusionGatherSubExp fres cond
  mergeFusionRes fres' both_res
fusionGatherExp fres (WithAcc inps lam) = do
  (_, fres') <- fusionGatherLam (mempty, fres) lam
  addNamesToInfusible fres' $ freeIn inps

-----------------------------------------------------------------------------------
--- Errors: Screma and Scatter must never reach this function, because
--- normalization ensures they appear directly in a let-binding
--- (i.e., let x = e), which fusionGatherStms handles.
--- NOTE(review): other SOACs are not rejected here; they fall through
--- to the generic traversal below.
-----------------------------------------------------------------------------------

fusionGatherExp _ (Op Futhark.Screma {}) = errorIllegal "screma"
fusionGatherExp _ (Op Futhark.Scatter {}) = errorIllegal "scatter"
-----------------------------------
---- Generic Traversal         ----
-----------------------------------

fusionGatherExp fres e = addNamesToInfusible fres $ freeIn e

-- | Record a use of a sub-expression.  A variable is marked infusible;
-- a constant leaves the result untouched.
fusionGatherSubExp :: FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp fres se =
  case se of
    Var v -> addVarToInfusible fres v
    _ -> return fres

-- | Mark every name in the given set as infusible.  The result is
-- threaded through 'addVarToInfusible' one name at a time.
addNamesToInfusible :: FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible fres names =
  foldM addVarToInfusible fres $ namesToList names

-- | Mark a single name as infusible.  If 'lookupArr' finds a SOAC input
-- for the name, the array underlying that input is marked instead of
-- the name itself.
addVarToInfusible :: FusedRes -> VName -> FusionGM FusedRes
addVarToInfusible fres name = do
  target <- maybe name underlyingArray <$> asks (lookupArr name)
  return fres {infusible = oneName target <> infusible fres}
  where
    underlyingArray (SOAC.Input _ orig _) = orig

-- | Gather fusion information from a lambda body.
--
-- Lambdas create a new scope.  Fusing from outside the lambda is
-- disallowed by adding the input arrays of kernels inside the lambda
-- to the infusible set.
--
-- Returns the given name set extended with the infusible names that
-- are in scope outside the lambda.  It is paired with the lambda's
-- result merged into the given 'FusedRes'.
fusionGatherLam :: (Names, FusedRes) -> Lambda -> FusionGM (Names, FusedRes)
fusionGatherLam (u_set, fres) (Lambda idds body _) = do
  new_res <- bindingParams idds $ fusionGatherBody mempty body
  -- Make the input arrays of kernels inside the lambda infusible, so
  -- that they cannot be fused from outside the lambda.
  let inp_arrs = namesFromList $ M.keys $ inpArr new_res
      unfus = infusible new_res <> inp_arrs
  -- Keep only names that are in scope here.  Names bound inside the
  -- lambda are not visible outside it.  Testing membership per name
  -- avoids converting the whole scope into a 'Names' set for every
  -- lambda.
  vars_in_scope <- asks varsInScope
  let unfus' = namesFromList $ filter (`M.member` vars_in_scope) $ namesToList unfus
      new_res' = new_res {infusible = unfus'}
  -- Merge the lambda's result with the incoming result.
  return (u_set <> unfus', new_res' <> fres)

-------------------------------------------------------------
-------------------------------------------------------------
--- FINALLY, Substitute the kernels in function
-------------------------------------------------------------
-------------------------------------------------------------

-- | Substitute fused kernels into a sequence of statements.  The tail is
-- processed first, with the head's pattern in scope.  The head
-- statement is then replaced by whatever 'replaceSOAC' produces for it.
fuseInStms :: Stms SOACS -> FusionGM (Stms SOACS)
fuseInStms stms =
  case stmsHead stms of
    Nothing -> pure mempty
    Just (Let pat aux e, rest) -> do
      rest' <- bindingPat pat $ fuseInStms rest
      replaced <- replaceSOAC pat aux e
      pure $ replaced <> rest'

-- | Substitute fused kernels into the statements of a body.  The result
-- is left unchanged, and the body decoration is reset to '()'.
fuseInBody :: Body -> FusionGM Body
fuseInBody (Body _ stms res) = do
  stms' <- fuseInStms stms
  pure $ Body () stms' res

fuseInExp :: Exp -> FusionGM Exp
-- Handle loop specially because we need to bind the types of the
-- merge variables.  For a for-loop, the loop index and the loop
-- variables are also bound before descending into the body.
fuseInExp (DoLoop merge form loopbody) =
  binding [(ident, mempty) | ident <- loopIdents] $
    bindingParams (map fst merge) $
      DoLoop merge form <$> fuseInBody loopbody
  where
    loopIdents =
      case form of
        ForLoop i it _ loopvars ->
          let indexIdent = Ident i $ Prim $ IntType it
           in indexIdent : [paramIdent p | (p, _) <- loopvars]
        WhileLoop {} -> []
-- Every other expression is traversed generically via 'fuseIn'.
fuseInExp e = mapExpM fuseIn e

-- | Expression mapper that substitutes fused kernels into nested bodies,
-- and into the lambdas of SOAC operators.  It ignores the scope passed
-- to 'mapOnBody'.
fuseIn :: Mapper SOACS SOACS FusionGM
fuseIn =
  identityMapper
    { mapOnBody = \_scope body -> fuseInBody body,
      mapOnOp = mapSOACM soacLambdaMapper
    }
  where
    soacLambdaMapper = identitySOACMapper {mapOnSOACLambda = fuseInLambda}

-- | Run 'fuseInBody' on the body of a lambda, with the lambda's
-- parameters brought into scope by 'bindingParams'.  The parameters and
-- return types of the lambda are left as they are.
fuseInLambda :: Lambda -> FusionGM Lambda
fuseInLambda lam = do
  body' <- bindingParams (lambdaParams lam) $ fuseInBody (lambdaBody lam)
  pure lam {lambdaBody = body'}

-- | Compute the statements that replace the statement @pat = e@ (with
-- auxiliary information @aux@) after fusion.
--
-- * An empty pattern produces no statements.
-- * If the first pattern element is not a key of 'outArr' in the current
--   fusion result, the statement is kept, with 'fuseInExp' applied to
--   its expression.
-- * Otherwise the statement is replaced by the kernel recorded for that
--   name, emitted through 'insertKerSOAC'.  It is an error for that
--   kernel to be absent from 'kernels', or to have no fused variables.
replaceSOAC :: Pat -> StmAux () -> Exp -> FusionGM (Stms SOACS)
replaceSOAC (Pat []) _ _ = pure mempty
replaceSOAC pat@(Pat (pe : _)) aux e = do
  fres <- asks fusedRes
  case M.lookup (patElemName pe) (outArr fres) of
    Nothing ->
      oneStm . Let pat aux <$> fuseInExp e
    Just knm ->
      maybe (missingKernel knm) emitKernel $ M.lookup knm (kernels fres)
  where
    missingKernel knm =
      throwError . Error $
        "In Fusion.hs, replaceSOAC, outArr in ker_name "
          ++ "which is not in Res: "
          ++ pretty (unKernName knm)

    emitKernel ker = do
      -- A kernel that fused nothing should have been pruned already.
      when (null (fusedVars ker)) $
        throwError . Error $
          "In Fusion.hs, replaceSOAC, unfused kernel "
            ++ "still in result: "
            ++ pretty (patIdents pat)
      insertKerSOAC aux (outNames ker) ker

-- | Emit the statements that compute a fused kernel.
--
-- The kernel's SOAC is finalised with 'finaliseSOAC' and converted to a
-- core SOAC.  Copies are inserted (via 'copyNewlyConsumed') for anything
-- it consumes beyond 'fusedConsumed'.  Its results are bound to fresh
-- identifiers whose base names are taken from @names@, under the
-- combined auxiliary information @kerAux ker <> aux@.  Finally
-- 'transformOutput' is applied to the kernel's output transform, @names@
-- and those fresh identifiers.
insertKerSOAC :: StmAux () -> [VName] -> FusedKer -> FusionGM (Stms SOACS)
insertKerSOAC aux names ker = do
  final_soac <- finaliseSOAC (fsoac ker)
  let res_types = SOAC.typeOf final_soac
      res_bases = map baseString names
  runBuilder_ $ do
    core_soac <- SOAC.toSOAC final_soac
    -- The fused kernel may consume more than the original SOACs (see
    -- issue #224).  We insert copy expressions to fix it.
    core_soac' <-
      copyNewlyConsumed (fusedConsumed ker) (addOpAliases mempty core_soac)
    res_idents <- zipWithM newIdent res_bases res_types
    auxing (kerAux ker <> aux) . letBind (basicPat res_idents) $ Op core_soac'
    transformOutput (outputTransform ker) names res_idents

-- | Perform simplification and fusion inside the lambda(s) of a SOAC.
--
-- Every lambda is passed through 'simplifyAndFuseInLambda'.  For a
-- 'SOAC.Screma', the scan lambdas are processed first, then the reduce
-- lambdas, then the map lambda.  For 'SOAC.Scatter', 'SOAC.Hist' and
-- 'SOAC.Stream', only the main lambda is processed.  All other fields
-- (widths, inputs, destinations, histogram operators, stream forms,
-- neutral elements) are returned unchanged.
finaliseSOAC :: SOAC.SOAC SOACS -> FusionGM (SOAC.SOAC SOACS)
finaliseSOAC new_soac =
  case new_soac of
    SOAC.Screma w (ScremaForm scans reds map_lam) arrs -> do
      scans' <- mapM finaliseScan scans
      reds' <- mapM finaliseReduce reds
      map_lam' <- simplifyAndFuseInLambda map_lam
      pure $ SOAC.Screma w (ScremaForm scans' reds' map_lam') arrs
    SOAC.Scatter w lam inps dests ->
      (\lam' -> SOAC.Scatter w lam' inps dests) <$> simplifyAndFuseInLambda lam
    SOAC.Hist w ops lam arrs ->
      (\lam' -> SOAC.Hist w ops lam' arrs) <$> simplifyAndFuseInLambda lam
    SOAC.Stream w form lam nes inps ->
      (\lam' -> SOAC.Stream w form lam' nes inps) <$> simplifyAndFuseInLambda lam
  where
    finaliseScan (Scan scan_lam scan_nes) =
      (\scan_lam' -> Scan scan_lam' scan_nes)
        <$> simplifyAndFuseInLambda scan_lam

    finaliseReduce (Reduce comm red_lam red_nes) =
      (\red_lam' -> Reduce comm red_lam' red_nes)
        <$> simplifyAndFuseInLambda red_lam

-- | Simplify a lambda, then fuse inside it.
--
-- 'fusionGatherLam' traverses the simplified lambda, starting from an
-- empty fusion result ('mkFreshFusionRes').  The gathered result is
-- pruned with 'cleanFusionResult' and made current with 'bindRes' while
-- 'fuseInLambda' rewrites the simplified lambda.
simplifyAndFuseInLambda :: Lambda -> FusionGM Lambda
simplifyAndFuseInLambda lam = do
  simplified <- simplifyLambda lam
  (_, gathered) <- fusionGatherLam (mempty, mkFreshFusionRes) simplified
  bindRes (cleanFusionResult gathered) $ fuseInLambda simplified

-- | Strip alias annotations from a fused SOAC, inserting copies for
-- names it consumes.  @was_consumed@ is the set of names that were
-- already consumed by the constituent SOACs.
--
-- Only 'Futhark.Screma' is rewritten:
--
-- * each input array in the newly consumed set (consumed by @soac@ but
--   not in @was_consumed@) is replaced by a @Copy@ bound in the builder,
--   named after the original with a @_copy@ suffix;
-- * every free variable consumed by the map lambda is copied at the top
--   of the lambda body, and the copy is substituted for the original
--   inside that body.  Note that this covers all consumed free
--   variables of the lambda, not only the newly consumed ones;
-- * scan and reduce lambdas merely lose their alias annotations.
--
-- Any other SOAC is returned with its alias annotations removed.
copyNewlyConsumed ::
  Names ->
  Futhark.SOAC (Aliases.Aliases SOACS) ->
  Builder SOACS (Futhark.SOAC SOACS)
copyNewlyConsumed was_consumed soac =
  case soac of
    Futhark.Screma w arrs (Futhark.ScremaForm scans reds map_lam) -> do
      -- Copy any arrays that are consumed now, but were not in the
      -- constituents.
      arrs' <- mapM copyIfNewlyConsumed arrs
      -- Consumed free variables of the map lambda are copied inside the
      -- lambda, and the copies are substituted for the originals.
      map_lam' <- copyFreeInLambda map_lam
      let scans' =
            [ scan {scanLambda = Aliases.removeLambdaAliases (scanLambda scan)}
              | scan <- scans
            ]
          reds' =
            [ red {redLambda = Aliases.removeLambdaAliases (redLambda red)}
              | red <- reds
            ]
      pure . Futhark.Screma w arrs' $ Futhark.ScremaForm scans' reds' map_lam'
    _ -> pure $ removeOpAliases soac
  where
    newly_consumed = consumedInOp soac `namesSubtract` was_consumed

    copyIfNewlyConsumed arr
      | arr `nameIn` newly_consumed =
          letExp (baseString arr <> "_copy") . BasicOp $ Copy arr
      | otherwise = pure arr

    copyFreeInLambda lam = do
      let lam_params = namesFromList $ map paramName $ lambdaParams lam
          free_consumed = consumedByLambda lam `namesSubtract` lam_params
      (copy_stms, subst) <-
        foldM copyFreeVar (mempty, mempty) (namesToList free_consumed)
      let lam' = Aliases.removeLambdaAliases lam
          body' = insertStms copy_stms . substituteNames subst $ lambdaBody lam'
      pure $
        if null copy_stms
          then lam'
          else lam' {lambdaBody = body'}

    -- Each copy statement is prepended to the accumulated statements.
    copyFreeVar (stms, subst) v = do
      v_copy <- newVName $ baseString v <> "_copy"
      copy_stm <- mkLetNamesM [v_copy] . BasicOp $ Copy v
      pure (oneStm copy_stm <> stms, M.insert v v_copy subst)

---------------------------------------------------
---------------------------------------------------
---- HELPERS
---------------------------------------------------
---------------------------------------------------

-- | An empty fusion result, for use when entering a new scope (e.g. a new
-- lambda or a new loop).  No success is recorded, all maps are empty
-- and nothing is marked infusible.
mkFreshFusionRes :: FusedRes
mkFreshFusionRes =
  FusedRes
    { rsucc = False,
      infusible = mempty,
      outArr = mempty,
      inpArr = mempty,
      kernels = mempty
    }

-- | Combine the fusion results of two sibling computations.
--
-- The keys shared by the 'inpArr' maps of both results are passed
-- through 'expandSoacInpArr' (defined elsewhere in this module).  Every
-- resulting name is marked infusible, together with everything already
-- infusible in either result.
--
-- 'outArr' and 'kernels' are combined with 'M.union', which is
-- left-biased: on a key collision the entry from @res1@ wins.  'inpArr'
-- entries are merged by set union, and 'rsucc' is the disjunction of
-- the two flags.
mergeFusionRes :: FusedRes -> FusedRes -> FusionGM FusedRes
mergeFusionRes res1 res2 = do
  let shared_inputs = M.keys $ inpArr res1 `M.intersection` inpArr res2
  inp_both <- expandSoacInpArr shared_inputs
  -- Record syntax (rather than positional construction) keeps this
  -- correct even if the fields of 'FusedRes' are ever reordered.
  pure
    FusedRes
      { rsucc = rsucc res1 || rsucc res2,
        outArr = outArr res1 `M.union` outArr res2,
        inpArr = M.unionWith S.union (inpArr res1) (inpArr res2),
        -- Inputs shared by both sides are marked infusible.
        infusible = infusible res1 <> infusible res2 <> namesFromList inp_both,
        kernels = kernels res1 `M.union` kernels res2
      }

-- | The expression arguments are supposed to be array-type exps.
--   Splits SOAC inputs into two lists of array names.  The first list
--   holds the inputs that are used as-is (no array transforms).  The
--   second holds those that are indexed, transposed or otherwise
--   transformed.
--
--   E.g., for expression @mapT(f, a, b[i])@, the result should be
--   @([a],[b])@.
--
--   Both lists are produced in the reverse of the input order, since
--   each name is consed onto an accumulator during a left fold.
getIdentArr :: [SOAC.Input] -> ([VName], [VName])
getIdentArr = L.foldl' comb ([], [])
  where
    -- Strict left fold: avoids building a chain of unevaluated 'comb'
    -- applications for long input lists.
    comb (vs, os) (SOAC.Input ts idd _)
      | SOAC.nullTransforms ts = (idd : vs, os)
    comb (vs, os) inp =
      (vs, SOAC.inputArray inp : os)

-- | Prune a fusion result down to the kernels that actually fused
-- something, i.e. those with a non-empty 'fusedVars'.
--
-- 'outArr' entries that point at a dropped kernel are removed, and
-- dropped kernels are filtered out of every 'inpArr' set.  An 'inpArr'
-- key whose set becomes empty is kept, mapped to the empty set.
-- 'rsucc' and 'infusible' are left unchanged.
cleanFusionResult :: FusedRes -> FusedRes
cleanFusionResult fres =
  fres {outArr = live_outs, inpArr = live_inps, kernels = live_kers}
  where
    live_kers = M.filter (not . null . fusedVars) (kernels fres)
    isLive knm = knm `M.member` live_kers
    live_outs = M.filter isLive (outArr fres)
    live_inps = S.filter isLive <$> inpArr fres

--------------
--- Errors ---
--------------

-- | Fail with an 'Error' reporting that the SOAC named by the argument
-- appears somewhere it is not allowed.
errorIllegal :: String -> FusionGM FusedRes
errorIllegal soac_name =
  throwError . Error $
    "In Fusion.hs, soac " ++ soac_name ++ " appears illegally in pgm!"