{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion (fuseSOACs) where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import qualified Futhark.IR.Aliases as Aliases
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.IR.SOACS.Simplify
import Futhark.Optimise.Fusion.LoopKernel
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (maxinum)
-- | What the fusion environment records about a variable in scope.
data VarEntry
  = -- | An array variable.  The fields are, in order: a name recorded
    -- for the entry ('bindVar' stores the variable's own name, while
    -- 'bindingTransform' carries over the name from the source entry),
    -- the variable's 'NameInfo', its alias set, and the 'SOAC.Input'
    -- through which the array can be consumed.
    IsArray VName (NameInfo SOACS) Names SOAC.Input
  | -- | A non-array variable; only its 'NameInfo' is kept.
    IsNotArray (NameInfo SOACS)
-- | The 'NameInfo' stored in a 'VarEntry', regardless of which
-- constructor it was built with.
varEntryType :: VarEntry -> NameInfo SOACS
varEntryType (IsArray _ dec _ _) = dec
varEntryType (IsNotArray dec) = dec
-- | The alias set recorded for an array entry.  Non-array entries
-- carry no alias information, so they yield 'mempty'.
varEntryAliases :: VarEntry -> Names
varEntryAliases (IsArray _ _ als _) = als
varEntryAliases IsNotArray {} = mempty
-- | The read-only environment threaded through 'FusionGM'.
data FusionGEnv = FusionGEnv
  { -- | Maps a variable name to its "family": the list of names it was
    -- bound together with (see 'bindingFamilyVar').
    soacs :: M.Map VName [VName],
    -- | Every variable currently in scope, with what is known about it.
    varsInScope :: M.Map VName VarEntry,
    -- | The fusion result made available to the computation
    -- (replaced locally by 'bindRes').
    fusedRes :: FusedRes
  }
-- | Look up the 'SOAC.Input' associated with a variable.  Yields
-- 'Nothing' when the variable is not in scope or is not an array.
lookupArr :: VName -> FusionGEnv -> Maybe SOAC.Input
lookupArr v env = M.lookup v (varsInScope env) >>= asArray
  where
    asArray :: VarEntry -> Maybe SOAC.Input
    asArray (IsArray _ _ _ input) = Just input
    asArray IsNotArray {} = Nothing
-- | An error raised by the fusion pass, carrying a textual message.
newtype Error = Error String
-- | Renders the message prefixed with a @Fusion error:@ header line.
instance Show Error where
  show (Error msg) = "Fusion error:\n" ++ msg
-- | The monad in which fusion information is gathered.  From the
-- outside in it provides: failure with an 'Error', a 'VNameSource'
-- state, and a 'FusionGEnv' reader environment.  All instances are
-- obtained from the underlying transformer stack via newtype deriving.
newtype FusionGM a = FusionGM (ExceptT Error (StateT VNameSource (Reader FusionGEnv)) a)
  deriving
    ( Monad,
      Applicative,
      Functor,
      MonadError Error,
      MonadState VNameSource,
      MonadReader FusionGEnv
    )
-- | The name source is exactly the 'VNameSource' state of 'FusionGM'.
instance MonadFreshNames FusionGM where
  getNameSource = get
  putNameSource = put
-- | The scope is derived from 'varsInScope' by keeping only the
-- 'NameInfo' of each entry.
instance HasScope SOACS FusionGM where
  askScope = asks (toScope . varsInScope)
    where
      toScope = M.map varEntryType
-- | Add one variable, together with its direct aliases, to the
-- variables in scope.
--
-- A variable of array type becomes an 'IsArray' entry whose input is
-- the plain 'SOAC.identInput' of the identifier; anything else becomes
-- 'IsNotArray'.  The alias set stored for an array is the given set
-- extended with the aliases already recorded in scope for each of its
-- members (one level of expansion).
bindVar :: FusionGEnv -> (Ident, Names) -> FusionGEnv
bindVar env (Ident name t, aliases) =
  env {varsInScope = M.insert name entry $ varsInScope env}
  where
    entry = case t of
      Array {} ->
        IsArray name (LetName t) aliases' $ SOAC.identInput $ Ident name t
      _ -> IsNotArray $ LetName t

    -- Aliases recorded in the current scope for a single name; names
    -- that are not in scope contribute nothing.
    expand = maybe mempty varEntryAliases . flip M.lookup (varsInScope env)

    aliases' = aliases <> mconcat (map expand $ namesToList aliases)
-- | Add several variables to the environment, left to right, via
-- 'bindVar'.  Later bindings see (and may shadow) earlier ones.
--
-- A strict left fold is used so that no chain of suspended record
-- updates builds up for long binding lists.
bindVars :: FusionGEnv -> [(Ident, Names)] -> FusionGEnv
bindVars = L.foldl' bindVar
-- | Run an action in an environment extended (via 'bindVars') with
-- the given variables and their aliases.
binding :: [(Ident, Names)] -> FusionGM a -> FusionGM a
binding vs = local (\env -> bindVars env vs)
-- | Run an action with the identifiers of a statement's pattern in
-- scope.  Each identifier is paired positionally with the alias set
-- that alias analysis (started from an empty alias table) reports for
-- the corresponding result of the expression.
gatherStmPat :: Pat -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPat pat e = binding $ zip idents aliases
  where
    idents = patIdents pat
    aliases = expAliases (Alias.analyseExp mempty e)
-- | Run an action with the identifiers of a pattern in scope, each
-- bound with an empty alias set.
bindingPat :: Pat -> FusionGM a -> FusionGM a
bindingPat pat = binding [(ident, mempty) | ident <- patIdents pat]
-- | Run an action with the given parameters in scope, each bound
-- with an empty alias set.
bindingParams :: Typed t => [Param t] -> FusionGM a -> FusionGM a
bindingParams params = binding [(paramIdent p, mempty) | p <- params]
-- | Register one identifier as a member of the given family.
--
-- Two things are updated: 'soacs' maps the identifier's name to the
-- whole family, and 'varsInScope' gets an 'IsArray' entry for it with
-- no aliases and the plain 'SOAC.identInput' as input.  Note that the
-- entry is always 'IsArray'; the identifier's type is not inspected.
bindingFamilyVar :: [VName] -> FusionGEnv -> Ident -> FusionGEnv
bindingFamilyVar faml env (Ident nm t) =
  env
    { soacs = M.insert nm faml $ soacs env,
      varsInScope = M.insert nm entry $ varsInScope env
    }
  where
    entry = IsArray nm (LetName t) mempty $ SOAC.identInput $ Ident nm t
-- | The variable itself together with the aliases recorded for it in
-- scope.  A variable that is not in scope yields just itself.
varAliases :: VName -> FusionGM Names
varAliases v = asks $ \env ->
  oneName v <> maybe mempty varEntryAliases (M.lookup v (varsInScope env))
-- | The union of 'varAliases' over every name in the given set.
varsAliases :: Names -> FusionGM Names
varsAliases names = mconcat <$> mapM varAliases (namesToList names)
-- | Record the effect of an in-place update on a fusion result.
--
-- The pair holds the arrays updated in place and any further
-- variables that must also be passed to @addVarToInfusible@ (defined
-- outside this section).  Every name of both lists is folded into the
-- result with @addVarToInfusible@; afterwards the aliases of the
-- in-place arrays (each array included, see 'varAliases') are added
-- to the @inplace@ set of every kernel in the result.
updateKerInPlaces :: FusedRes -> ([VName], [VName]) -> FusionGM FusedRes
updateKerInPlaces res (ip_vs, other_infuse_vs) = do
  res' <- foldM addVarToInfusible res (ip_vs ++ other_infuse_vs)
  aliases <- mconcat <$> mapM varAliases ip_vs
  let inspectKer k = k {inplace = aliases <> inplace k}
  return res' {kernels = M.map inspectKer $ kernels res'}
-- | Update a fusion result if the expression writes to an array in
-- place.
--
-- * @Update@ and @FlatUpdate@: the source array is treated as
--   updated in place, and the free variables of the slice are passed
--   along as the additional names.
-- * @Scatter@: every array named in the write information is treated
--   as updated in place.
-- * Any other expression leaves the result unchanged.
checkForUpdates :: FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates res (BasicOp (Update _ src slice _)) = do
  let ifvs = namesToList $ freeIn slice
  updateKerInPlaces res ([src], ifvs)
checkForUpdates res (BasicOp (FlatUpdate src slice _)) = do
  let ifvs = namesToList $ freeIn slice
  updateKerInPlaces res ([src], ifvs)
checkForUpdates res (Op (Futhark.Scatter _ _ _ written_info)) = do
  -- Only the destination array of each (shape, count, array) triple
  -- is needed here.
  let updt_arrs = map (\(_, _, x) -> x) written_info
  updateKerInPlaces res (updt_arrs, [])
checkForUpdates res _ = return res
-- | Run an action in an environment where every identifier of the
-- pattern has been registered with 'bindingFamilyVar', using the full
-- list of the pattern's names as the family of each.
--
-- A strict left fold is used so that intermediate environments are
-- not left as a chain of suspended record updates.
bindingFamily :: Pat -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily pat = local bind
  where
    idents = patIdents pat
    family = patNames pat
    bind env = L.foldl' (bindingFamilyVar family) env idents
-- | Run an action in an environment where the pattern element is
-- bound as the result of applying an array transformation to the
-- source variable.
--
-- If the source is a known array, the new entry reuses the name
-- recorded in the source's entry, takes its type from the pattern
-- element, aliases the source and everything the source aliases, and
-- has the source's input with the transformation added on.
-- Otherwise the pattern element is bound through 'bindVar', with an
-- alias set containing only its own name.
bindingTransform :: PatElem -> VName -> SOAC.ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform pe srcname trns = local $ \env ->
  case M.lookup srcname $ varsInScope env of
    Just (IsArray src' _ aliases input) ->
      env
        { varsInScope =
            M.insert
              vname
              ( IsArray src' (LetName dec) (oneName srcname <> aliases) $
                  trns `SOAC.addTransform` input
              )
              $ varsInScope env
        }
    _ -> bindVar env (patElemIdent pe, oneName vname)
  where
    vname = patElemName pe
    dec = patElemDec pe
-- | Run an action with the given fusion result installed as the
-- 'fusedRes' of the environment.
bindRes :: FusedRes -> FusionGM a -> FusionGM a
bindRes rrr = local $ \env -> env {fusedRes = rrr}
-- | Run a 'FusionGM' computation in the given environment.
--
-- The name source of the surrounding monad is used as the initial
-- state and is replaced by the final state afterwards (through
-- 'modifyNameSource').  A failure inside the computation is returned
-- as 'Left' rather than thrown.
runFusionGatherM ::
  MonadFreshNames m =>
  FusionGM a ->
  FusionGEnv ->
  m (Either Error a)
runFusionGatherM (FusionGM a) env =
  modifyNameSource $ \src -> runReader (runStateT (runExceptT a) src) env
-- | The fusion pass.  Constants are processed with 'fuseConsts'
-- (given the free variables of all functions of the program) and
-- each function with 'fuseFun'; the resulting program is then renamed
-- and finally simplified.
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = \prog ->
        simplifySOACS
          =<< renameProg
          =<< intraproceduralTransformationWithConsts
            (fuseConsts (freeIn (progFuns prog)))
            fuseFun
            prog
    }
-- | Fuse the statements that define the program's constants.  They
-- are processed in an empty scope, and the given names are presented
-- to 'fuseStms' as the result of the statement sequence.
fuseConsts :: Names -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts used_consts consts =
  fuseStms mempty consts $ varsRes $ namesToList used_consts
-- | Fuse the body of one function.  The statements are processed in
-- a scope made of the constants and the function's parameters; only
-- the statements of the body are replaced, its result is kept as is.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  stms <-
    fuseStms
      (scopeOf consts <> scopeOfFParams (funDefParams fun))
      (bodyStms $ funDefBody fun)
      (bodyResult $ funDefBody fun)
  let body = (funDefBody fun) {bodyStms = stms}
  return fun {funDefBody = body}
-- | Fuse the SOACs in a sequence of statements whose result is @res@.
--
-- Works in two passes, both run from the same empty 'FusionGEnv' with the
-- names of @scope@ bound (with no aliases):
--
--   1. 'fusionGatherStms' collects the fusion result for @stms@, which is
--      then post-processed by 'cleanFusionResult'.
--
--   2. If that result reports success ('rsucc'), 'fuseInStms' rebuilds the
--      statements with the result installed via 'bindRes'.  Otherwise the
--      original @stms@ are returned untouched.
--
-- An error produced by either pass is surfaced through 'liftEitherM'.
fuseStms :: Scope SOACS -> Stms SOACS -> Result -> PassM (Stms SOACS)
fuseStms scope stms res = do
  k <-
    cleanFusionResult
      <$> runGather (fusionGatherStms mempty (stmsToList stms) res)
  if rsucc k
    then runGather $ bindRes k $ fuseInStms stms
    else return stms
  where
    env :: FusionGEnv
    env =
      FusionGEnv
        { soacs = M.empty,
          varsInScope = mempty,
          fusedRes = mempty
        }

    -- Every name already in scope, as an identifier with an empty alias set.
    scope' :: [(Ident, Names)]
    scope' = [(Ident name (typeOf info), mempty) | (name, info) <- M.toList scope]

    -- Run a fusion action in the initial environment with the scope bound.
    runGather :: FusionGM a -> PassM a
    runGather m = liftEitherM $ runFusionGatherM (binding scope' m) env
-- | The name of a fused kernel.  A distinct wrapper around 'VName' so that
-- kernel names cannot be confused with the names of ordinary variables.
newtype KernName = KernName {unKernName :: VName}
  deriving (Eq, Ord, Show)
-- | The state accumulated while gathering fusion opportunities.
data FusedRes = FusedRes
  { -- | Whether anything has been fused.  'greedyFuse' sets this to 'True'
    -- when it fuses into existing kernels, and 'fuseStms' only rewrites
    -- the program when it is set.
    rsucc :: Bool,
    -- | Maps an array name to the kernel that produces it.
    outArr :: M.Map VName KernName,
    -- | Maps an array name to the kernels that take it as input.
    inpArr :: M.Map VName (S.Set KernName),
    -- | Names that must not be fused.  'greedyFuse' routes a SOAC whose
    -- output is in this set to horizontal fusion only.
    infusible :: Names,
    -- | The kernels gathered so far, by name.
    kernels :: M.Map KernName FusedKer
  }
-- | Combines two fusion results field by field: success flags are OR-ed,
-- the user sets in 'inpArr' are unioned per array, the 'infusible' names
-- are joined, and 'outArr' and 'kernels' are unioned with 'M.union', which
-- prefers the left operand's entry when a key occurs on both sides.
instance Semigroup FusedRes where
  res1 <> res2 =
    FusedRes
      { rsucc = rsucc res1 || rsucc res2,
        outArr = outArr res1 `M.union` outArr res2,
        inpArr = M.unionWith S.union (inpArr res1) (inpArr res2),
        infusible = infusible res1 <> infusible res2,
        kernels = kernels res1 `M.union` kernels res2
      }
-- | The empty fusion result: nothing fused, no kernels, no recorded
-- producers or consumers, and no infusible names.
instance Monoid FusedRes where
  mempty =
    FusedRes
      { rsucc = False,
        outArr = M.empty,
        inpArr = M.empty,
        infusible = mempty,
        kernels = M.empty
      }
-- | Is the array @nm@ used as input by some kernel of @ress@ other than
-- those in @kers@?  An array with no entry in 'inpArr' is used by none.
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
  maybe False usedElsewhere $ M.lookup nm $ inpArr ress
  where
    usedElsewhere users = not $ S.null $ users `S.difference` kers
-- | The names of all kernels in the fusion result that take at least one of
-- the given arrays as input.  Arrays with no entry in 'inpArr' contribute
-- nothing.
getKersWithInpArrs :: FusedRes -> [VName] -> S.Set KernName
getKersWithInpArrs ress =
  S.unions . mapMaybe (`M.lookup` inpArr ress)
-- | Expand a list of names through the 'soacs' table of the environment.
--
-- A name with an entry in the table is replaced, in place, by the list of
-- names stored for it; any other name is kept as is.  The relative order of
-- the input is preserved.
expandSoacInpArr :: [VName] -> FusionGM [VName]
expandSoacInpArr nms = do
  -- The environment is read once; it cannot change between elements.
  table <- asks soacs
  return $ concatMap (\nm -> M.findWithDefault [nm] nm table) nms
-- | The input names of a SOAC, as the two groups into which 'getIdentArr'
-- splits its inputs, each expanded with 'expandSoacInpArr'.
soacInputs :: SOAC -> FusionGM ([VName], [VName])
soacInputs soac = do
  let (inp_idds, other_idds) = getIdentArr $ SOAC.inputs soac
  inp_nms <- expandSoacInpArr inp_idds
  other_nms <- expandSoacInpArr other_idds
  return (inp_nms, other_nms)
-- | Register a SOAC as a new kernel of its own in the fusion result.
--
-- A fresh kernel name is generated and a kernel is built with 'newKernel'
-- in the current scope.  Every output of the SOAC is recorded as produced
-- by the new kernel (overriding any earlier producer for that name), and
-- the new kernel is added to the users of every input array of the SOAC.
-- The 'infusible' set of the result is replaced by @ufs@; the success flag
-- is carried over unchanged.
addNewKerWithInfusible :: FusedRes -> ([Ident], StmAux (), SOAC, Names) -> Names -> FusionGM FusedRes
addNewKerWithInfusible res (idd, aux, soac, consumed) ufs = do
  nm_ker <- KernName <$> newVName "ker"
  scope <- askScope
  let out_nms = map identName idd
      new_ker = newKernel aux soac consumed out_nms scope
      produced = M.fromList [(arr, nm_ker) | arr <- out_nms]
      consumes =
        M.fromList
          [ (arr, S.singleton nm_ker)
            | arr <- map SOAC.inputArray $ SOAC.inputs soac
          ]
  return $
    FusedRes
      { rsucc = rsucc res,
        outArr = produced `M.union` outArr res,
        inpArr = M.unionWith S.union consumes (inpArr res),
        infusible = ufs,
        kernels = M.insert nm_ker new_ker (kernels res)
      }
-- | Look up the SOAC input recorded for a name in the current environment
-- (see 'lookupArr'), if there is one.
lookupInput :: VName -> FusionGM (Maybe SOAC.Input)
lookupInput name = asks $ lookupArr name
-- | Replace the array of a SOAC input by the input recorded for it in the
-- environment, if any.
--
-- When the environment has an input for @v@, the result reads from that
-- input's array and type, and its transformations are the recorded ones
-- followed by the transformations @ts@ of the original input.  Otherwise
-- the input is returned unchanged.
inlineSOACInput :: SOAC.Input -> FusionGM SOAC.Input
inlineSOACInput inp@(SOAC.Input ts v _) =
  maybe inp inline <$> lookupInput v
  where
    inline (SOAC.Input ts2 v2 t2) = SOAC.Input (ts2 <> ts) v2 t2
-- | Apply 'inlineSOACInput' to every input of a SOAC.
inlineSOACInputs :: SOAC -> FusionGM SOAC
inlineSOACInputs soac = do
  inputs' <- mapM inlineSOACInput $ SOAC.inputs soac
  return $ inputs' `SOAC.setInputs` soac
-- | Try to fuse one SOAC statement with the kernels gathered so far, and
-- return the updated fusion result.
--
-- Arguments: the remaining statements (only passed on to
-- 'horizontGreedyFuse'), the names used inside the SOAC's lambda, the
-- fusion result so far, and the statement itself as (pattern, auxiliary
-- information, SOAC, consumed names).
--
-- The SOAC's inputs are first inlined with 'inlineSOACInputs'.  Candidate
-- kernels come from 'horizontGreedyFuse' when the SOAC is a redomap or
-- scanomap that is not a plain reduce or scan, or when one of its outputs
-- is infusible; otherwise from 'prodconsGreedyFuse'.
--
-- Fusion goes ahead only if there is at least one candidate kernel, the
-- candidates were reported compatible, and no candidate has an in-place
-- name among the aliases of the names this SOAC uses.  If so, the
-- candidates' entries in 'inpArr' are replaced by entries for the fused
-- kernels and 'rsucc' becomes 'True'; if not, the SOAC becomes a new kernel
-- via 'addNewKerWithInfusible'.
greedyFuse ::
  [Stm] ->
  Names ->
  FusedRes ->
  (Pat, StmAux (), SOAC, Names) ->
  FusionGM FusedRes
greedyFuse rem_stms lam_used_nms res (out_idds, aux, orig_soac, consumed) = do
  soac <- inlineSOACInputs orig_soac
  (inp_nms, other_nms) <- soacInputs soac
  let out_nms = patNames out_idds
      isInfusible = (`nameIn` infusible res)
      -- Note that this inspects the SOAC as given, before input inlining.
      is_screma = case orig_soac of
        SOAC.Screma _ form _ ->
          (isJust (isRedomapSOAC form) || isJust (isScanomapSOAC form))
            && not (isJust (isReduceSOAC form) || isJust (isScanSOAC form))
        _ -> False
      candidate = (out_idds, aux, soac, consumed)

  (ok_kers_compat, fused_kers, fused_nms, old_kers, oldker_nms) <-
    if is_screma || any isInfusible out_nms
      then horizontGreedyFuse rem_stms res candidate
      else prodconsGreedyFuse res candidate

  -- Everything this SOAC touches, closed under aliasing.
  all_used_names <-
    fmap mconcat . mapM varAliases . namesToList $
      mconcat [lam_used_nms, namesFromList inp_nms, namesFromList other_nms]

  let has_inplace ker = inplace ker `namesIntersect` all_used_names
      ok_inplace = not $ any has_inplace old_kers
      fusible_ker = not (null old_kers) && ok_inplace && ok_kers_compat
      -- The kernels that will disappear if fusion goes ahead.
      mod_kerS = if fusible_ker then S.fromList oldker_nms else mempty
      -- Inputs that are still used by some kernel that is not being replaced.
      used_inps = filter (isInpArrInResModKers res mod_kerS) inp_nms
      ufs =
        mconcat
          [ infusible res,
            namesFromList used_inps,
            namesFromList other_nms
              `namesSubtract` namesFromList (map SOAC.inputArray $ SOAC.inputs soac)
          ]

  if not fusible_ker
    then addNewKerWithInfusible res (patIdents out_idds, aux, soac, consumed) ufs
    else do
      let fused_ker_nms = zip fused_nms fused_kers
          -- Drop kernel @knm@ from the users of array @nm@, removing the
          -- entry altogether once no kernel uses @nm@ any more.
          forget knm = M.update (nonNull . S.delete knm)
          nonNull s = if S.null s then Nothing else Just s
          -- The input table with the replaced kernels removed...
          inpArr' =
            L.foldl'
              (\inpa (kold, knm) -> S.foldl' (flip (forget knm)) inpa (arrInputs kold))
              (inpArr res)
              (zip old_kers oldker_nms)
          -- ...and the fused kernels registered on their own inputs.
          inpArr'' =
            L.foldl'
              ( \inpa (knm, knew) ->
                  M.unionWith
                    S.union
                    (M.fromSet (const (S.singleton knm)) (arrInputs knew))
                    inpa
              )
              inpArr'
              fused_ker_nms
          -- Fused kernels reuse the old names, so they take precedence.
          kernels' = M.fromList fused_ker_nms `M.union` kernels res
      return $
        FusedRes
          { rsucc = True,
            outArr = outArr res,
            inpArr = inpArr'',
            infusible = ufs,
            kernels = kernels'
          }
-- | Find the kernels that consume an output of the given SOAC and try to
-- fuse the SOAC into each of them with 'attemptFusion'.
--
-- Returns @(ok, fused_kers, fused_nms, old_kers, old_nms)@:
--
--   * @ok@ is 'True' only if fusion succeeded for every candidate kernel
--     (vacuously so when there are no candidates);
--
--   * @fused_kers@ are the fused kernels, each with the auxiliary
--     information of the SOAC's statement appended to its own; it is empty
--     when @ok@ is 'False';
--
--   * @old_kers@ are the candidate kernels before fusion;
--
--   * @fused_nms@ and @old_nms@ are both the candidates' names, since a
--     fused kernel keeps the name of the kernel it replaces.
--
-- Throws an 'Error' if a candidate name has no entry in 'kernels'.
prodconsGreedyFuse ::
  FusedRes ->
  (Pat, StmAux (), SOAC, Names) ->
  FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
prodconsGreedyFuse res (out_idds, aux, soac, consumed) = do
  let out_nms = patNames out_idds
      to_fuse_knms = S.toList $ getKersWithInpArrs res out_nms
      lookup_kern k = case M.lookup k (kernels res) of
        Nothing ->
          throwError $
            Error
              ( "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
                  ++ "kernel name not found in kernels field!"
              )
        Just ker -> return ker
  to_fuse_kers <- mapM lookup_kern to_fuse_knms
  -- Fusion is attempted on every candidate, even after one has failed, and
  -- only then is the outcome inspected.
  attempts <- mapM (attemptFusion mempty out_nms soac consumed) to_fuse_kers
  let (ok_kers_compat, fused_kers) = case sequence attempts of
        Nothing -> (False, [])
        Just kers -> (True, map certifyKer kers)
  return (ok_kers_compat, fused_kers, to_fuse_knms, to_fuse_kers, to_fuse_knms)
  where
    -- Carry the statement's auxiliary information over to the fused kernel.
    certifyKer k = k {kerAux = kerAux k <> aux}
horizontGreedyFuse ::
[Stm] ->
FusedRes ->
(Pat, StmAux (), SOAC, Names) ->
FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse :: [Stm]
-> FusedRes
-> (Pat, StmAux (), SOAC, Names)
-> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse [Stm]
rem_stms FusedRes
res (Pat
out_idds, StmAux ()
aux, SOAC
soac, Names
consumed) = do
([VName]
inp_nms, [VName]
_) <- SOAC -> FusionGM ([VName], [VName])
soacInputs SOAC
soac
let out_nms :: [VName]
out_nms = PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
Pat
out_idds
infusible_nms :: Names
infusible_nms = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` FusedRes -> Names
infusible FusedRes
res) [VName]
out_nms
out_arr_nms :: [VName]
out_arr_nms = case SOAC
soac of
SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans [Reduce SOACS]
reds Lambda SOACS
_) [Input]
_ ->
Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Scan SOACS] -> Int
forall rep. [Scan rep] -> Int
scanResults [Scan SOACS]
scans Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce SOACS] -> Int
forall rep. [Reduce rep] -> Int
redResults [Reduce SOACS]
reds) [VName]
out_nms
SOAC.Stream SubExp
_ StreamForm SOACS
_ Lambda SOACS
_ [SubExp]
nes [Input]
_ -> Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [VName]
out_nms
SOAC
_ -> [VName]
out_nms
to_fuse_knms1 :: [KernName]
to_fuse_knms1 = Set KernName -> [KernName]
forall a. Set a -> [a]
S.toList (Set KernName -> [KernName]) -> Set KernName -> [KernName]
forall a b. (a -> b) -> a -> b
$ FusedRes -> [VName] -> Set KernName
getKersWithInpArrs FusedRes
res ([VName]
out_arr_nms [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
inp_nms)
to_fuse_knms2 :: [KernName]
to_fuse_knms2 = SubExp -> FusedRes -> [KernName]
getKersWithSameInpSize (SOAC -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC
soac) FusedRes
res
to_fuse_knms :: [KernName]
to_fuse_knms = Set KernName -> [KernName]
forall a. Set a -> [a]
S.toList (Set KernName -> [KernName]) -> Set KernName -> [KernName]
forall a b. (a -> b) -> a -> b
$ [KernName] -> Set KernName
forall a. Ord a => [a] -> Set a
S.fromList ([KernName] -> Set KernName) -> [KernName] -> Set KernName
forall a b. (a -> b) -> a -> b
$ [KernName]
to_fuse_knms1 [KernName] -> [KernName] -> [KernName]
forall a. [a] -> [a] -> [a]
++ [KernName]
to_fuse_knms2
lookupKernel :: KernName -> FusionGM FusedKer
lookupKernel KernName
k = case KernName -> Map KernName FusedKer -> Maybe FusedKer
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup KernName
k (FusedRes -> Map KernName FusedKer
kernels FusedRes
res) of
Maybe FusedKer
Nothing ->
Error -> FusionGM FusedKer
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Error -> FusionGM FusedKer) -> Error -> FusionGM FusedKer
forall a b. (a -> b) -> a -> b
$
String -> Error
Error
( String
"In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"kernel name not found in kernels field!"
)
Just FusedKer
ker -> FusedKer -> FusionGM FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return FusedKer
ker
let stm_nms :: [[VName]]
stm_nms = (Stm -> [VName]) -> [Stm] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map (PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames (PatT Type -> [VName]) -> (Stm -> PatT Type) -> Stm -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm -> PatT Type
forall rep. Stm rep -> Pat rep
stmPat) [Stm]
rem_stms
[Maybe (FusedKer, KernName, Int)]
kernminds <- [KernName]
-> (KernName -> FusionGM (Maybe (FusedKer, KernName, Int)))
-> FusionGM [Maybe (FusedKer, KernName, Int)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [KernName]
to_fuse_knms ((KernName -> FusionGM (Maybe (FusedKer, KernName, Int)))
-> FusionGM [Maybe (FusedKer, KernName, Int)])
-> (KernName -> FusionGM (Maybe (FusedKer, KernName, Int)))
-> FusionGM [Maybe (FusedKer, KernName, Int)]
forall a b. (a -> b) -> a -> b
$ \KernName
ker_nm -> do
FusedKer
ker <- KernName -> FusionGM FusedKer
lookupKernel KernName
ker_nm
case (VName -> Maybe Int) -> [VName] -> [Int]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (\VName
out_nm -> ([VName] -> Bool) -> [[VName]] -> Maybe Int
forall a. (a -> Bool) -> [a] -> Maybe Int
L.findIndex (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
elem VName
out_nm) [[VName]]
stm_nms) (FusedKer -> [VName]
outNames FusedKer
ker) of
[] -> Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (FusedKer, KernName, Int)
forall a. Maybe a
Nothing
[Int]
is -> Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int)))
-> Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int))
forall a b. (a -> b) -> a -> b
$ (FusedKer, KernName, Int) -> Maybe (FusedKer, KernName, Int)
forall a. a -> Maybe a
Just (FusedKer
ker, KernName
ker_nm, [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum [Int]
is)
Scope SOACS
scope <- FusionGM (Scope SOACS)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
let kernminds' :: [(FusedKer, KernName, Int)]
kernminds' = ((FusedKer, KernName, Int)
-> (FusedKer, KernName, Int) -> Ordering)
-> [(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a. (a -> a -> Ordering) -> [a] -> [a]
L.sortBy (\(FusedKer
_, KernName
_, Int
i1) (FusedKer
_, KernName
_, Int
i2) -> Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare Int
i1 Int
i2) ([(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)])
-> [(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a b. (a -> b) -> a -> b
$ [Maybe (FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (FusedKer, KernName, Int)]
kernminds
soac_kernel :: FusedKer
soac_kernel = StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel StmAux ()
aux SOAC
soac Names
consumed [VName]
out_nms Scope SOACS
scope
(Bool
_, Int
ok_ind, Int
_, FusedKer
fused_ker, Names
_) <-
((Bool, Int, Int, FusedKer, Names)
-> (FusedKer, KernName, Int)
-> FusionGM (Bool, Int, Int, FusedKer, Names))
-> (Bool, Int, Int, FusedKer, Names)
-> [(FusedKer, KernName, Int)]
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
( \(Bool
cur_ok, Int
n, Int
prev_ind, FusedKer
cur_ker, Names
ufus_nms) (FusedKer
ker, KernName
_ker_nm, Int
stm_ind) -> do
let curker_outnms :: [VName]
curker_outnms = FusedKer -> [VName]
outNames FusedKer
cur_ker
curker_outset :: Names
curker_outset = [VName] -> Names
namesFromList [VName]
curker_outnms
new_ufus_nms :: Names
new_ufus_nms = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Names -> [VName]
namesToList Names
ufus_nms
out_transf_ok :: Bool
out_transf_ok =
let ker_inp :: [Input]
ker_inp = SOAC -> [Input]
forall rep. SOAC rep -> [Input]
SOAC.inputs (SOAC -> [Input]) -> SOAC -> [Input]
forall a b. (a -> b) -> a -> b
$ FusedKer -> SOAC
fsoac FusedKer
ker
unfuse1 :: Names
unfuse1 =
[VName] -> Names
namesFromList ((Input -> VName) -> [Input] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Input -> VName
SOAC.inputArray [Input]
ker_inp)
Names -> Names -> Names
`namesSubtract` [VName] -> Names
namesFromList ((Input -> Maybe VName) -> [Input] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Input -> Maybe VName
SOAC.isVarInput [Input]
ker_inp)
unfuse2 :: Names
unfuse2 = Names -> Names -> Names
namesIntersection Names
curker_outset Names
ufus_nms
in Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names
unfuse1 Names -> Names -> Bool
`namesIntersect` Names
unfuse2
cons_no_out_transf :: Bool
cons_no_out_transf = ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedKer -> ArrayTransforms
outputTransform FusedKer
ker
let consumer_ok :: Bool
consumer_ok =
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
Names
curker_outset
Names -> Names -> Bool
`namesIntersect` BodyT SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn (Lambda SOACS -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody (Lambda SOACS -> BodyT SOACS) -> Lambda SOACS -> BodyT SOACS
forall a b. (a -> b) -> a -> b
$ SOAC -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda (SOAC -> Lambda SOACS) -> SOAC -> Lambda SOACS
forall a b. (a -> b) -> a -> b
$ FusedKer -> SOAC
fsoac FusedKer
ker)
let interm_stms_ok :: Bool
interm_stms_ok =
Bool
cur_ok Bool -> Bool -> Bool
&& Bool
consumer_ok Bool -> Bool -> Bool
&& Bool
out_transf_ok Bool -> Bool -> Bool
&& Bool
cons_no_out_transf
Bool -> Bool -> Bool
&& (Bool -> Stm -> Bool) -> Bool -> [Stm] -> Bool
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
( \Bool
ok Stm
stm ->
Bool
ok
Bool -> Bool -> Bool
&& Bool -> Bool
not (Names
curker_outset Names -> Names -> Bool
`namesIntersect` Exp -> Names
forall a. FreeIn a => a -> Names
freeIn (Stm -> Exp
forall rep. Stm rep -> Exp rep
stmExp Stm
stm))
Bool -> Bool -> Bool
||
Bool -> Bool
not
( [VName] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
[VName]
curker_outnms
[VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
`L.intersect` PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames (Stm -> Pat
forall rep. Stm rep -> Pat rep
stmPat Stm
stm)
)
)
Bool
True
(Int -> [Stm] -> [Stm]
forall a. Int -> [a] -> [a]
drop (Int
prev_ind Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) ([Stm] -> [Stm]) -> [Stm] -> [Stm]
forall a b. (a -> b) -> a -> b
$ Int -> [Stm] -> [Stm]
forall a. Int -> [a] -> [a]
take Int
stm_ind [Stm]
rem_stms)
if Bool -> Bool
not Bool
interm_stms_ok
then (Bool, Int, Int, FusedKer, Names)
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
False, Int
n, Int
stm_ind, FusedKer
cur_ker, Names
forall a. Monoid a => a
mempty)
else do
Maybe FusedKer
new_ker <-
Names
-> [VName]
-> SOAC
-> Names
-> FusedKer
-> FusionGM (Maybe FusedKer)
forall (m :: * -> *).
MonadFreshNames m =>
Names -> [VName] -> SOAC -> Names -> FusedKer -> m (Maybe FusedKer)
attemptFusion
Names
ufus_nms
(FusedKer -> [VName]
outNames FusedKer
cur_ker)
(FusedKer -> SOAC
fsoac FusedKer
cur_ker)
(FusedKer -> Names
fusedConsumed FusedKer
cur_ker)
FusedKer
ker
case Maybe FusedKer
new_ker of
Maybe FusedKer
Nothing -> (Bool, Int, Int, FusedKer, Names)
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
False, Int
n, Int
stm_ind, FusedKer
cur_ker, Names
forall a. Monoid a => a
mempty)
Just FusedKer
krn ->
let krn' :: FusedKer
krn' = FusedKer
krn {kerAux :: StmAux ()
kerAux = StmAux ()
aux StmAux () -> StmAux () -> StmAux ()
forall a. Semigroup a => a -> a -> a
<> FusedKer -> StmAux ()
kerAux FusedKer
krn}
in (Bool, Int, Int, FusedKer, Names)
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
True, Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1, Int
stm_ind, FusedKer
krn', Names
new_ufus_nms)
)
(Bool
True, Int
0, Int
0, FusedKer
soac_kernel, Names
infusible_nms)
[(FusedKer, KernName, Int)]
kernminds'
let ([FusedKer]
to_fuse_kers', [KernName]
to_fuse_knms', [Int]
_) = [(FusedKer, KernName, Int)] -> ([FusedKer], [KernName], [Int])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(FusedKer, KernName, Int)] -> ([FusedKer], [KernName], [Int]))
-> [(FusedKer, KernName, Int)] -> ([FusedKer], [KernName], [Int])
forall a b. (a -> b) -> a -> b
$ Int -> [(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a. Int -> [a] -> [a]
take Int
ok_ind [(FusedKer, KernName, Int)]
kernminds'
new_kernms :: [KernName]
new_kernms = Int -> [KernName] -> [KernName]
forall a. Int -> [a] -> [a]
drop (Int
ok_ind Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) [KernName]
to_fuse_knms'
(Bool, [FusedKer], [KernName], [FusedKer], [KernName])
-> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Int
ok_ind Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0, [FusedKer
fused_ker], [KernName]
new_kernms, [FusedKer]
to_fuse_kers', [KernName]
to_fuse_knms')
where
getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
getKersWithSameInpSize SubExp
sz FusedRes
ress =
((KernName, FusedKer) -> KernName)
-> [(KernName, FusedKer)] -> [KernName]
forall a b. (a -> b) -> [a] -> [b]
map (KernName, FusedKer) -> KernName
forall a b. (a, b) -> a
fst ([(KernName, FusedKer)] -> [KernName])
-> [(KernName, FusedKer)] -> [KernName]
forall a b. (a -> b) -> a -> b
$ ((KernName, FusedKer) -> Bool)
-> [(KernName, FusedKer)] -> [(KernName, FusedKer)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(KernName
_, FusedKer
ker) -> SubExp
sz SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width (FusedKer -> SOAC
fsoac FusedKer
ker)) ([(KernName, FusedKer)] -> [(KernName, FusedKer)])
-> [(KernName, FusedKer)] -> [(KernName, FusedKer)]
forall a b. (a -> b) -> a -> b
$ Map KernName FusedKer -> [(KernName, FusedKer)]
forall k a. Map k a -> [(k, a)]
M.toList (Map KernName FusedKer -> [(KernName, FusedKer)])
-> Map KernName FusedKer -> [(KernName, FusedKer)]
forall a b. (a -> b) -> a -> b
$ FusedRes -> Map KernName FusedKer
kernels FusedRes
ress
-- | Gather fusion information for a whole body.
--
-- The body is taken apart into its statements and its result; the
-- statements are converted to a plain list with 'stmsToList' and the
-- actual work is delegated to 'fusionGatherStms'.  The body decoration
-- is ignored.
fusionGatherBody :: FusedRes -> Body -> FusionGM FusedRes
fusionGatherBody fres (Body _ stms res) =
  fusionGatherStms fres (stmsToList stms) res
-- | Gather fusion information for a list of statements followed by the
-- result @res@ of the enclosing body.
--
-- The statements are processed head-first, but the tail of the list is
-- always handled by a recursive call /before/ the head statement is
-- offered to 'greedyFuse'.
--
-- There are three equations:
--
-- 1. A @for@-loop that iterates over at least one array (non-empty
--    @loop_vars@) is rewritten into a sequential 'Futhark.Stream' whose
--    lambda contains an inner loop without loop variables, and the
--    rewritten statement is then processed again.
--
-- 2. Any other statement is dispatched on whether its expression can be
--    viewed as a SOAC ('SOAC.fromExp'), as an array transformation
--    ('SOAC.transformFromExp'), or as neither.
--
-- 3. When no statements remain, the result sub-expressions are passed
--    through 'fusionGatherExp'.
fusionGatherStms :: FusedRes -> [Stm] -> Result -> FusionGM FusedRes
-- Equation 1: for-loop with loop variables => sequential stream.  The
-- inner loop built below is given an empty list of loop variables, so
-- the guard of this equation cannot match it again.
fusionGatherStms
  fres
  (Let (Pat pes) stmtp (DoLoop merge (ForLoop i it w loop_vars) body) : stms)
  res
    | not $ null loop_vars = do
        let (merge_params, merge_init) = unzip merge
            (loop_params, loop_arrs) = unzip loop_vars

        -- Parameters of the stream lambda: the chunk size, one
        -- accumulator per merge parameter, the running offset, and one
        -- chunk per array the original loop iterated over.
        chunk_param <- newParam "chunk_size" $ Prim int64
        offset_param <- newParam "offset" $ Prim $ IntType it
        let offset = paramName offset_param
            chunk_size = paramName chunk_param

        acc_params <- forM merge_params $ \p ->
          newParam (baseString (paramName p) ++ "_outer") (paramType p)

        chunked_params <- forM loop_vars $ \(p, arr) ->
          newParam
            (baseString arr ++ "_chunk")
            (paramType p `arrayOfRow` Futhark.Var chunk_size)

        let lam_params =
              chunk_param : acc_params ++ [offset_param] ++ chunked_params

        lam_body <- runBodyBuilder $
          localScope (scopeOfLParams lam_params) $ do
            -- The inner loop reuses the original merge parameters, but
            -- is initialised from the lambda's accumulator parameters.
            let merge' =
                  zip merge_params $ map (Futhark.Var . paramName) acc_params
            j <- newVName "j"
            loop_body <- runBodyBuilder $ do
              -- Re-create each original loop parameter by indexing the
              -- corresponding chunk at the inner index @j@.
              forM_ (zip loop_params chunked_params) $ \(p, a_p) ->
                letBindNames [paramName p] $
                  BasicOp $
                    Index (paramName a_p) $
                      fullSlice (paramType a_p) [DimFix $ Futhark.Var j]
              -- Re-create the original loop index as @offset + j@.
              letBindNames [i] $
                BasicOp $
                  BinOp
                    (Add it OverflowUndef)
                    (Futhark.Var offset)
                    (Futhark.Var j)
              pure body
            -- The lambda returns the inner loop's results followed by
            -- the offset advanced by the chunk size.
            --
            -- NOTE(review): this addition uses 'Int64' (and the
            -- discarded pattern element below is typed @int64@), while
            -- @offset_param@ is declared with @IntType it@.  That looks
            -- consistent only when @it@ is 'Int64' -- confirm.
            eBody
              [ pure $
                  DoLoop
                    merge'
                    (ForLoop j it (Futhark.Var chunk_size) [])
                    loop_body,
                pure $
                  BasicOp $
                    BinOp
                      (Add Int64 OverflowUndef)
                      (Futhark.Var offset)
                      (Futhark.Var chunk_size)
              ]

        let lam =
              Lambda
                { lambdaParams = lam_params,
                  lambdaBody = lam_body,
                  lambdaReturnType =
                    map paramType $ acc_params ++ [offset_param]
                }
            -- The initial accumulators are the loop's initial merge
            -- values followed by an initial offset of zero.
            stream =
              Futhark.Stream
                w
                loop_arrs
                Sequential
                (merge_init ++ [intConst it 0])
                lam

        -- The stream yields one value more than the loop did (the final
        -- offset); bind it to a fresh name appended at the end of the
        -- original pattern.
        discard <- newVName "discard"
        let discard_pe = PatElem discard $ Prim int64

        fusionGatherStms
          fres
          (Let (Pat (pes <> [discard_pe])) stmtp (Op stream) : stms)
          res
-- Equation 2: every other statement.
fusionGatherStms fres (stm@(Let pat _ e) : stms) res = do
  maybesoac <- SOAC.fromExp e
  case maybesoac of
    Right soac@(SOAC.Scatter _len lam _ivs _as) -> do
      -- The names produced by the Scatter are added to the infusible
      -- set before it is handled like a map; presumably this prevents
      -- producer/consumer fusion through them -- TODO confirm.
      fres' <- addNamesToInfusible fres $ namesFromList $ patNames pat
      fres'' <- mapLike fres' soac lam
      checkForUpdates fres'' e
    Right soac@(SOAC.Hist _ _ lam _) -> do
      -- Same treatment as Scatter, minus the 'checkForUpdates' step.
      fres' <- addNamesToInfusible fres $ namesFromList $ patNames pat
      mapLike fres' soac lam
    Right soac@(SOAC.Screma _ (ScremaForm scans reds map_lam) _) ->
      -- All scan, reduction and map lambdas are gathered, together with
      -- the neutral elements of the scans and reductions.
      reduceLike soac (map scanLambda scans <> map redLambda reds <> [map_lam]) $
        concatMap scanNeutral scans <> concatMap redNeutral reds
    Right soac@(SOAC.Stream _ form lam nes _) -> do
      -- A parallel stream carries an additional lambda that must be
      -- gathered as well.
      let lambdas = case form of
            Parallel _ _ lout -> [lout, lam]
            Sequential -> [lam]
      reduceLike soac lambdas nes
    _
      -- Not a SOAC, but a single-result array transformation: record
      -- the transformation while processing the remaining statements.
      | Pat [pe] <- pat,
        Just (src, trns) <- SOAC.transformFromExp (stmCerts stm) e ->
          bindingTransform pe src trns $ fusionGatherStms fres stms res
      -- Anything else: process the remaining statements with the
      -- pattern bound, then gather from the expression itself and from
      -- each pattern name viewed as a variable expression.
      | otherwise -> do
          let pat_vars = map (BasicOp . SubExp . Var) $ patNames pat
          bres <- gatherStmPat pat e $ fusionGatherStms fres stms res
          bres' <- checkForUpdates bres e
          foldM fusionGatherExp bres' (e : pat_vars)
  where
    aux = stmAux stm

    -- The current statement followed by everything after it.
    rem_stms = stm : stms

    -- Names consumed by the expression, computed by alias analysis
    -- started from an empty alias table.
    consumed = consumedInExp $ Alias.analyseExp mempty e

    -- Gather from the lambdas first, then from the remaining
    -- statements, then from the neutral elements, and finally try to
    -- fuse the SOAC.
    reduceLike soac lambdas nes = do
      (used_lam, lres) <- foldM fusionGatherLam (mempty, fres) lambdas
      bres <- bindingFamily pat $ fusionGatherStms lres stms res
      bres' <- foldM fusionGatherSubExp bres nes
      consumed' <- varsAliases consumed
      greedyFuse rem_stms used_lam bres' (pat, aux, soac, consumed')

    -- Gather from the remaining statements first, then from the
    -- lambda, and finally try to fuse the SOAC.
    mapLike fres' soac lambda = do
      bres <- bindingFamily pat $ fusionGatherStms fres' stms res
      (used_lam, blres) <- fusionGatherLam (mempty, bres) lambda
      consumed' <- varsAliases consumed
      greedyFuse rem_stms used_lam blres (pat, aux, soac, consumed')
-- Equation 3: no statements left; gather from the body result.
fusionGatherStms fres [] res =
  foldM fusionGatherExp fres $ map (BasicOp . SubExp . resSubExp) res
-- | Gather fusion information from a single expression, extending the
-- result accumulated so far (@fres@).
--
-- * 'DoLoop': every name free in the loop form or the merge parameters is
--   added to the infusible set of @fres@.  The loop body is analysed
--   starting from an empty result, with the loop index, the loop-variable
--   parameters and the merge parameters bound in scope.
-- * 'If': both branches are analysed from an empty result and combined;
--   the condition is recorded via 'fusionGatherSubExp', and the two
--   results are joined with 'mergeFusionRes'.
-- * 'WithAcc': the lambda is analysed via 'fusionGatherLam' and the names
--   free in the inputs are added to the infusible set.
-- * A bare 'Futhark.Screma' or 'Futhark.Scatter' is reported through
--   'errorIllegal'.
-- * Anything else: all free names of the expression become infusible.
fusionGatherExp :: FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp fres (DoLoop merge form loop_body) = do
  fres' <- addNamesToInfusible fres $ freeIn form <> freeIn merge
  let merge_idents = map (paramIdent . fst) merge
      form_idents =
        case form of
          ForLoop i it _ loopvars ->
            Ident i (Prim (IntType it)) : map (paramIdent . fst) loopvars
          WhileLoop {} -> []
  new_res <-
    binding (zip (form_idents ++ merge_idents) $ repeat mempty) $
      fusionGatherBody mempty loop_body
  -- Every array that is an input to a kernel found inside the loop body
  -- is marked infusible (presumably so that nothing outside the loop
  -- fuses into it -- confirm against the fusion rules).
  let inp_arrs = namesFromList $ M.keys $ inpArr new_res
      new_res' = new_res {infusible = infusible new_res <> inp_arrs}
  return $ new_res' <> fres'
fusionGatherExp fres (If cond e_then e_else _) = do
  then_res <- fusionGatherBody mempty e_then
  else_res <- fusionGatherBody mempty e_else
  let both_res = then_res <> else_res
  fres' <- fusionGatherSubExp fres cond
  mergeFusionRes fres' both_res
fusionGatherExp fres (WithAcc inps lam) = do
  (_, fres') <- fusionGatherLam (mempty, fres) lam
  addNamesToInfusible fres' $ freeIn inps
-- SOACs reaching this function are reported as errors.
-- NOTE(review): presumably they are expected to have been picked up at
-- statement level already -- confirm against fusionGatherStms.
fusionGatherExp _ (Op Futhark.Screma {}) = errorIllegal "screma"
fusionGatherExp _ (Op Futhark.Scatter {}) = errorIllegal "write"
-- Generic case: nothing to fuse, so every free name is infusible.
fusionGatherExp fres e = addNamesToInfusible fres $ freeIn e
-- | Record a sub-expression in the fusion result: a variable is added to
-- the infusible set (via 'addVarToInfusible'); any other sub-expression
-- leaves the result unchanged.
fusionGatherSubExp :: FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp fres (Var idd) = addVarToInfusible fres idd
fusionGatherSubExp fres _ = return fres
-- | Add every name in the given set to the infusible set of the result,
-- one at a time through 'addVarToInfusible'.
addNamesToInfusible :: FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible fres = foldM addVarToInfusible fres . namesToList
-- | Add a single variable to the infusible set of the result.
--
-- If the environment knows the variable as an array ('lookupArr' returns
-- a 'SOAC.Input'), the array name stored in that input is recorded in
-- place of the variable's own name; otherwise the variable name itself
-- is recorded.
addVarToInfusible :: FusedRes -> VName -> FusionGM FusedRes
addVarToInfusible fres name = do
  trns <- asks $ lookupArr name
  let name' = case trns of
        Nothing -> name
        Just (SOAC.Input _ orig _) -> orig
  return fres {infusible = oneName name' <> infusible fres}
-- | Gather fusion information from the body of a lambda.
--
-- The body is analysed from an empty result with the lambda parameters
-- bound.  The infusible set of that result is extended with every array
-- that is an input to a kernel found in the body, and then restricted to
-- the names currently in scope (the keys of 'varsInScope').
--
-- Returns the incoming name set extended with that restricted infusible
-- set, paired with the lambda's result combined with the incoming @fres@.
fusionGatherLam :: (Names, FusedRes) -> Lambda -> FusionGM (Names, FusedRes)
fusionGatherLam (u_set, fres) (Lambda idds body _) = do
  new_res <- bindingParams idds $ fusionGatherBody mempty body
  in_scope <- asks $ M.keys . varsInScope
  let inp_arrs = namesFromList $ M.keys $ inpArr new_res
      unfus = infusible new_res <> inp_arrs
      -- Keep only names visible in the enclosing scope; names bound
      -- inside the lambda are dropped by this intersection.
      unfus' = unfus `namesIntersection` namesFromList in_scope
      new_res' = new_res {infusible = unfus'}
  return (u_set <> unfus', new_res' <> fres)
-- | Rewrite a statement sequence using the fusion result stored in the
-- environment.
--
-- The remaining statements are processed first, with the head
-- statement's pattern bound in scope; the head statement is then
-- rewritten by 'replaceSOAC' and its output is placed in front.  An
-- empty sequence yields an empty sequence.
fuseInStms :: Stms SOACS -> FusionGM (Stms SOACS)
fuseInStms stms
  | Just (Let pat aux e, stms') <- stmsHead stms = do
      stms'' <- bindingPat pat $ fuseInStms stms'
      soac_stms <- replaceSOAC pat aux e
      pure $ soac_stms <> stms''
  | otherwise =
      pure mempty
-- | Rewrite the statements of a body with 'fuseInStms'; the body result
-- is kept unchanged and the body decoration is rebuilt as @()@.
fuseInBody :: Body -> FusionGM Body
fuseInBody (Body _ stms res) =
  Body () <$> fuseInStms stms <*> pure res
-- | Rewrite an expression using the fusion result in the environment.
--
-- Loops are handled explicitly so that the loop index, the loop-variable
-- parameters and the merge parameters are bound in scope while the loop
-- body is rewritten.  Every other expression is traversed with the
-- 'fuseIn' mapper.
fuseInExp :: Exp -> FusionGM Exp
fuseInExp (DoLoop merge form loopbody) =
  binding (zip form_idents $ repeat mempty) $
    bindingParams (map fst merge) $
      DoLoop merge form <$> fuseInBody loopbody
  where
    -- Identifiers introduced by the loop form itself; a while-loop
    -- introduces none.
    form_idents = case form of
      WhileLoop {} -> []
      ForLoop i it _ loopvars ->
        Ident i (Prim $ IntType it) : map (paramIdent . fst) loopvars
fuseInExp e = mapExpM fuseIn e
-- | Mapper used by 'fuseInExp' for the generic traversal: bodies are
-- rewritten with 'fuseInBody' (the scope argument is ignored) and the
-- lambdas of SOAC operations with 'fuseInLambda'.  All other mapper
-- fields keep the behaviour of 'identityMapper'.
fuseIn :: Mapper SOACS SOACS FusionGM
fuseIn =
  identityMapper
    { mapOnBody = const fuseInBody,
      mapOnOp = mapSOACM identitySOACMapper {mapOnSOACLambda = fuseInLambda}
    }
-- | Rewrite the body of a lambda with 'fuseInBody', with the lambda
-- parameters bound in scope.  Parameters and return types are unchanged.
fuseInLambda :: Lambda -> FusionGM Lambda
fuseInLambda (Lambda params body rtp) = do
  body' <- bindingParams params $ fuseInBody body
  return $ Lambda params body' rtp
-- | Produce the replacement statements for one original statement.
--
-- * An empty pattern yields no statements.
-- * If the first pattern element is not registered in 'outArr' of the
--   fusion result, the statement is kept, with its expression rewritten
--   by 'fuseInExp'.
-- * Otherwise the registered kernel is looked up in 'kernels' and its
--   SOAC is emitted through 'insertKerSOAC'.
--
-- Throws an 'Error' when the kernel name is missing from 'kernels', or
-- when the kernel found has an empty 'fusedVars' list.
replaceSOAC :: Pat -> StmAux () -> Exp -> FusionGM (Stms SOACS)
replaceSOAC (Pat []) _ _ = return mempty
replaceSOAC pat@(Pat (patElem : _)) aux e = do
  fres <- asks fusedRes
  let pat_nm = patElemName patElem
      names = patIdents pat
  case M.lookup pat_nm (outArr fres) of
    Nothing ->
      oneStm . Let pat aux <$> fuseInExp e
    Just knm ->
      case M.lookup knm (kernels fres) of
        Nothing ->
          throwError $
            Error
              ( "In Fusion.hs, replaceSOAC, outArr in ker_name "
                  ++ "which is not in Res: "
                  ++ pretty (unKernName knm)
              )
        Just ker -> do
          -- A kernel with no fused variables is treated as an error here.
          when (null $ fusedVars ker) $
            throwError $
              Error
                ( "In Fusion.hs, replaceSOAC, unfused kernel "
                    ++ "still in result: "
                    ++ pretty names
                )
          insertKerSOAC aux (outNames ker) ker
-- | Build the statements that materialise a fused kernel.
--
-- The kernel's SOAC is first passed through 'finaliseSOAC'.  Inside a
-- builder it is converted with 'SOAC.toSOAC', annotated with aliases
-- (using an empty alias table) and handed to 'copyNewlyConsumed'
-- together with the kernel's 'fusedConsumed' set.  The result is bound
-- to fresh identifiers, one per entry of @names@ and named after its
-- base string, under the combined auxiliary information of the kernel
-- and @aux@.  Finally 'transformOutput' is called with the kernel's
-- 'outputTransform', @names@ and the fresh identifiers (presumably to
-- bind @names@ from them -- confirm).
insertKerSOAC :: StmAux () -> [VName] -> FusedKer -> FusionGM (Stms SOACS)
insertKerSOAC aux names ker = do
  new_soac' <- finaliseSOAC $ fsoac ker
  runBuilder_ $ do
    f_soac <- SOAC.toSOAC new_soac'
    f_soac' <- copyNewlyConsumed (fusedConsumed ker) $ addOpAliases mempty f_soac
    validents <- zipWithM newIdent (map baseString names) $ SOAC.typeOf new_soac'
    auxing (kerAux ker <> aux) $ letBind (basicPat validents) $ Op f_soac'
    transformOutput (outputTransform ker) names validents
-- | Run 'simplifyAndFuseInLambda' on the lambdas of a SOAC and rebuild
-- it with all other components unchanged.
--
-- For 'SOAC.Screma' the scan lambdas are processed first, then the
-- reduction lambdas, then the map lambda.  For 'SOAC.Scatter',
-- 'SOAC.Hist' and 'SOAC.Stream' only the SOAC's own lambda argument is
-- processed.
finaliseSOAC :: SOAC.SOAC SOACS -> FusionGM (SOAC.SOAC SOACS)
finaliseSOAC new_soac =
  case new_soac of
    SOAC.Screma w (ScremaForm scans reds map_lam) arrs -> do
      scans' <- forM scans $ \(Scan scan_lam scan_nes) -> do
        scan_lam' <- simplifyAndFuseInLambda scan_lam
        return $ Scan scan_lam' scan_nes
      reds' <- forM reds $ \(Reduce comm red_lam red_nes) -> do
        red_lam' <- simplifyAndFuseInLambda red_lam
        return $ Reduce comm red_lam' red_nes
      map_lam' <- simplifyAndFuseInLambda map_lam
      return $ SOAC.Screma w (ScremaForm scans' reds' map_lam') arrs
    SOAC.Scatter w lam inps dests -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Scatter w lam' inps dests
    SOAC.Hist w ops lam arrs -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Hist w ops lam' arrs
    SOAC.Stream w form lam nes inps -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Stream w form lam' nes inps
-- | Simplify a lambda and then fuse inside it.
--
-- The lambda is simplified with 'simplifyLambda'; fusion information is
-- gathered for the simplified lambda starting from 'mkFreshFusionRes';
-- the gathered result is passed through 'cleanFusionResult' and
-- installed with 'bindRes' while 'fuseInLambda' rewrites the simplified
-- lambda.
simplifyAndFuseInLambda :: Lambda -> FusionGM Lambda
simplifyAndFuseInLambda lam = do
  lam' <- simplifyLambda lam
  (_, nfres) <- fusionGatherLam (mempty, mkFreshFusionRes) lam'
  let nfres' = cleanFusionResult nfres
  bindRes nfres' $ fuseInLambda lam'
copyNewlyConsumed ::
Names ->
Futhark.SOAC (Aliases.Aliases SOACS) ->
Builder SOACS (Futhark.SOAC SOACS)
copyNewlyConsumed :: Names
-> SOAC (Aliases SOACS)
-> BuilderT SOACS (State VNameSource) (SOAC SOACS)
copyNewlyConsumed Names
was_consumed SOAC (Aliases SOACS)
soac =
case SOAC (Aliases SOACS)
soac of
Futhark.Screma SubExp
w [VName]
arrs (Futhark.ScremaForm [Scan (Aliases SOACS)]
scans [Reduce (Aliases SOACS)]
reds Lambda (Aliases SOACS)
map_lam) -> do
[VName]
arrs' <- (VName -> BuilderT SOACS (State VNameSource) VName)
-> [VName] -> BuilderT SOACS (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BuilderT SOACS (State VNameSource) VName
copyConsumedArr [VName]
arrs
Lambda SOACS
map_lam' <- Lambda (Aliases (Rep (BuilderT SOACS (State VNameSource))))
-> BuilderT
SOACS
(State VNameSource)
(Lambda (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
(CanBeAliased (Op (Rep m)), MonadBuilder m, Buildable (Rep m)) =>
Lambda (Aliases (Rep m)) -> m (Lambda (Rep m))
copyFreeInLambda Lambda (Aliases (Rep (BuilderT SOACS (State VNameSource))))
Lambda (Aliases SOACS)
map_lam
let scans' :: [Scan SOACS]
scans' =
(Scan (Aliases SOACS) -> Scan SOACS)
-> [Scan (Aliases SOACS)] -> [Scan SOACS]
forall a b. (a -> b) -> [a] -> [b]
map
( \Scan (Aliases SOACS)
scan ->
Scan (Aliases SOACS)
scan
{ scanLambda :: Lambda SOACS
scanLambda =
Lambda (Aliases SOACS) -> Lambda SOACS
forall rep.
CanBeAliased (Op rep) =>
Lambda (Aliases rep) -> Lambda rep
Aliases.removeLambdaAliases
(Scan (Aliases SOACS) -> Lambda (Aliases SOACS)
forall rep. Scan rep -> Lambda rep
scanLambda Scan (Aliases SOACS)
scan)
}
)
[Scan (Aliases SOACS)]
scans
let reds' :: [Reduce SOACS]
reds' =
(Reduce (Aliases SOACS) -> Reduce SOACS)
-> [Reduce (Aliases SOACS)] -> [Reduce SOACS]
forall a b. (a -> b) -> [a] -> [b]
map
( \Reduce (Aliases SOACS)
red ->
Reduce (Aliases SOACS)
red
{ redLambda :: Lambda SOACS
redLambda =
Lambda (Aliases SOACS) -> Lambda SOACS
forall rep.
CanBeAliased (Op rep) =>
Lambda (Aliases rep) -> Lambda rep
Aliases.removeLambdaAliases
(Reduce (Aliases SOACS) -> Lambda (Aliases SOACS)
forall rep. Reduce rep -> Lambda rep
redLambda Reduce (Aliases SOACS)
red)
}
)
[Reduce (Aliases SOACS)]
reds
SOAC SOACS -> BuilderT SOACS (State VNameSource) (SOAC SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC SOACS -> BuilderT SOACS (State VNameSource) (SOAC SOACS))
-> SOAC SOACS -> BuilderT SOACS (State VNameSource) (SOAC SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
w [VName]
arrs' (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ [Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
Futhark.ScremaForm [Scan SOACS]
scans' [Reduce SOACS]
reds' Lambda SOACS
map_lam'
SOAC (Aliases SOACS)
_ -> SOAC SOACS -> BuilderT SOACS (State VNameSource) (SOAC SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC SOACS -> BuilderT SOACS (State VNameSource) (SOAC SOACS))
-> SOAC SOACS -> BuilderT SOACS (State VNameSource) (SOAC SOACS)
forall a b. (a -> b) -> a -> b
$ OpWithAliases (SOAC SOACS) -> SOAC SOACS
forall op. CanBeAliased op => OpWithAliases op -> op
removeOpAliases OpWithAliases (SOAC SOACS)
SOAC (Aliases SOACS)
soac
where
consumed :: Names
consumed = SOAC (Aliases SOACS) -> Names
forall op. AliasedOp op => op -> Names
consumedInOp SOAC (Aliases SOACS)
soac
newly_consumed :: Names
newly_consumed = Names
consumed Names -> Names -> Names
`namesSubtract` Names
was_consumed
copyConsumedArr :: VName -> BuilderT SOACS (State VNameSource) VName
copyConsumedArr VName
a
| VName
a VName -> Names -> Bool
`nameIn` Names
newly_consumed =
String
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
a String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_copy") (Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) VName)
-> Exp (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT SOACS (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp) -> BasicOp -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
a
| Bool
otherwise = VName -> BuilderT SOACS (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
a
copyFreeInLambda :: Lambda (Aliases (Rep m)) -> m (Lambda (Rep m))
copyFreeInLambda Lambda (Aliases (Rep m))
lam = do
let free_consumed :: Names
free_consumed =
Lambda (Aliases (Rep m)) -> Names
forall rep. Aliased rep => Lambda rep -> Names
consumedByLambda Lambda (Aliases (Rep m))
lam
Names -> Names -> Names
`namesSubtract` [VName] -> Names
namesFromList ((Param (LParamInfo (Rep m)) -> VName)
-> [Param (LParamInfo (Rep m))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (LParamInfo (Rep m)) -> VName
forall dec. Param dec -> VName
paramName ([Param (LParamInfo (Rep m))] -> [VName])
-> [Param (LParamInfo (Rep m))] -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases (Rep m)) -> [LParam (Aliases (Rep m))]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda (Aliases (Rep m))
lam)
(Seq (Stm (Rep m))
stms, Map VName VName
subst) <-
((Seq (Stm (Rep m)), Map VName VName)
-> VName -> m (Seq (Stm (Rep m)), Map VName VName))
-> (Seq (Stm (Rep m)), Map VName VName)
-> [VName]
-> m (Seq (Stm (Rep m)), Map VName VName)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Seq (Stm (Rep m)), Map VName VName)
-> VName -> m (Seq (Stm (Rep m)), Map VName VName)
forall (m :: * -> *).
MonadBuilder m =>
(Stms (Rep m), Map VName VName)
-> VName -> m (Stms (Rep m), Map VName VName)
copyFree (Seq (Stm (Rep m))
forall a. Monoid a => a
mempty, Map VName VName
forall a. Monoid a => a
mempty) ([VName] -> m (Seq (Stm (Rep m)), Map VName VName))
-> [VName] -> m (Seq (Stm (Rep m)), Map VName VName)
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList Names
free_consumed
let lam' :: Lambda (Rep m)
lam' = Lambda (Aliases (Rep m)) -> Lambda (Rep m)
forall rep.
CanBeAliased (Op rep) =>
Lambda (Aliases rep) -> Lambda rep
Aliases.removeLambdaAliases Lambda (Aliases (Rep m))
lam
Lambda (Rep m) -> m (Lambda (Rep m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda (Rep m) -> m (Lambda (Rep m)))
-> Lambda (Rep m) -> m (Lambda (Rep m))
forall a b. (a -> b) -> a -> b
$
if Seq (Stm (Rep m)) -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Seq (Stm (Rep m))
stms
then Lambda (Rep m)
lam'
else
Lambda (Rep m)
lam'
{ lambdaBody :: BodyT (Rep m)
lambdaBody =
Seq (Stm (Rep m)) -> BodyT (Rep m) -> BodyT (Rep m)
forall rep. Buildable rep => Stms rep -> Body rep -> Body rep
insertStms Seq (Stm (Rep m))
stms (BodyT (Rep m) -> BodyT (Rep m)) -> BodyT (Rep m) -> BodyT (Rep m)
forall a b. (a -> b) -> a -> b
$
Map VName VName -> BodyT (Rep m) -> BodyT (Rep m)
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
subst (BodyT (Rep m) -> BodyT (Rep m)) -> BodyT (Rep m) -> BodyT (Rep m)
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> BodyT (Rep m)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Rep m)
lam'
}
copyFree :: (Stms (Rep m), Map VName VName)
-> VName -> m (Stms (Rep m), Map VName VName)
copyFree (Stms (Rep m)
stms, Map VName VName
subst) VName
v = do
VName
v_copy <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> m VName) -> String -> m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
v String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_copy"
Stm (Rep m)
copy <- [VName] -> Exp (Rep m) -> m (Stm (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m (Stm (Rep m))
mkLetNamesM [VName
v_copy] (Exp (Rep m) -> m (Stm (Rep m))) -> Exp (Rep m) -> m (Stm (Rep m))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v
(Stms (Rep m), Map VName VName)
-> m (Stms (Rep m), Map VName VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stm (Rep m) -> Stms (Rep m)
forall rep. Stm rep -> Stms rep
oneStm Stm (Rep m)
copy Stms (Rep m) -> Stms (Rep m) -> Stms (Rep m)
forall a. Semigroup a => a -> a -> a
<> Stms (Rep m)
stms, VName -> VName -> Map VName VName -> Map VName VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
v VName
v_copy Map VName VName
subst)
-- | The starting point for a fusion traversal: nothing has been fused
-- yet ('rsucc' is 'False'), no kernels are registered, no arrays are
-- recorded as kernel inputs or outputs, and no names are marked as
-- infusible.
mkFreshFusionRes :: FusedRes
mkFreshFusionRes =
  FusedRes
    { rsucc = False,
      outArr = mempty,
      inpArr = mempty,
      infusible = mempty,
      kernels = mempty
    }
-- | Combine two fusion results into a single one.
--
-- * 'rsucc' is the disjunction of both flags.
-- * 'outArr' and 'kernels' are merged with a left-biased union, so on a
--   key clash the entry from the first result is kept.
-- * 'inpArr' is merged key-wise, taking the union of the kernel sets.
-- * 'infusible' is the union of both infusible sets, extended with
--   whatever 'expandSoacInpArr' returns for the names that are keys of
--   'inpArr' in /both/ results.
mergeFusionRes :: FusedRes -> FusedRes -> FusionGM FusedRes
mergeFusionRes res1 res2 = do
  let inputs1 = inpArr res1
      inputs2 = inpArr res2
      already_infusible = infusible res1 <> infusible res2
  -- Names appearing as inputs on both sides; the expansion itself is
  -- delegated to 'expandSoacInpArr'.
  shared_inputs <- expandSoacInpArr (M.keys (M.intersection inputs1 inputs2))
  pure $
    FusedRes
      { rsucc = rsucc res1 || rsucc res2,
        outArr = M.union (outArr res1) (outArr res2),
        inpArr = M.unionWith S.union inputs1 inputs2,
        infusible = already_infusible <> foldMap oneName shared_inputs,
        kernels = M.union (kernels res1) (kernels res2)
      }
-- | Split a list of SOAC inputs into two lists of array names.
--
-- The first component holds the arrays of those inputs that carry no
-- array transforms (as decided by 'SOAC.nullTransforms'); the second
-- holds the underlying array ('SOAC.inputArray') of every other input.
--
-- Both lists are built by consing during a left fold, so each one is
-- in the /reverse/ of the order in which the inputs were given.
getIdentArr :: [SOAC.Input] -> ([VName], [VName])
getIdentArr = L.foldl' classify ([], [])
  where
    -- A strict left fold is used so the accumulator pair is forced at
    -- every step rather than piling up a chain of thunks.
    classify (plain, transformed) inp =
      case inp of
        SOAC.Input ts arr _
          | SOAC.nullTransforms ts -> (arr : plain, transformed)
        _ -> (plain, SOAC.inputArray inp : transformed)
-- | Prune a fusion result of kernels that did not fuse anything.
--
-- A kernel is kept only if its 'fusedVars' list is non-empty.  Entries
-- of 'outArr' that point to a removed kernel are dropped, and removed
-- kernels are deleted from every kernel set in 'inpArr'.  Keys of
-- 'inpArr' are retained even when their set ends up empty.  All other
-- fields of the result are left untouched.
cleanFusionResult :: FusedRes -> FusedRes
cleanFusionResult fres =
  fres
    { outArr = live_outputs,
      inpArr = live_inputs,
      kernels = live_kernels
    }
  where
    live_kernels = M.filter (not . null . fusedVars) (kernels fres)
    stillAlive kname = M.member kname live_kernels
    live_outputs = M.filter stillAlive (outArr fres)
    live_inputs = M.map (S.filter stillAlive) (inpArr fres)
-- | Abort the fusion pass by throwing an 'Error' whose message reports
-- that the SOAC with the given name appears illegally in the program.
errorIllegal :: String -> FusionGM FusedRes
errorIllegal soac_name = throwError (Error message)
  where
    message =
      concat
        [ "In Fusion.hs, soac ",
          soac_name,
          " appears illegally in pgm!"
        ]