{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion.TryFusion
( FusedSOAC (..),
Mode (..),
attemptFusion,
)
where
import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.MapNest qualified as MapNest
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)
-- | The monad in which a single fusion attempt runs.
--
-- It reads a type 'Scope', threads a 'VNameSource' for fresh names,
-- and sits on top of 'Maybe'.  Failure is therefore represented by
-- 'Nothing', which is what gives us the 'Alternative' and 'MonadFail'
-- instances used to abandon an attempt and try the next rule.
newtype TryFusion a
  = TryFusion (ReaderT (Scope SOACS) (StateT VNameSource Maybe) a)
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Run a fusion attempt in the given scope.
--
-- If the attempt succeeds, its result is returned in 'Just' and the
-- name source is advanced past any names it generated.  If it fails,
-- 'Nothing' is returned and the name source is restored to the value
-- it had before the attempt.
tryFusion ::
  (MonadFreshNames m) =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion attempt) scope =
  modifyNameSource $ \src ->
    -- On failure, hand back the untouched source; on success, wrap the
    -- result in 'Just' and keep the updated source.
    maybe (Nothing, src) (first Just) $
      runStateT (runReaderT attempt scope) src
-- | Embed a 'Maybe' into 'TryFusion'.  'Nothing' becomes a failed
-- attempt (via 'fail'), and 'Just' becomes a successful one.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") pure
-- | Shorthand for the higher-order SOAC representation specialised to
-- the 'SOACS' IR.
type SOAC = SOAC.SOAC SOACS
-- | Shorthand for map nests specialised to the 'SOACS' IR.
type MapNest = MapNest.MapNest SOACS
-- | Split the first array transformation off an input.
--
-- Returns that transformation together with the same input stripped
-- of it, or 'Nothing' when the input carries no transformations.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input transforms arr arr_t)
  | t SOAC.:< rest <- SOAC.viewf transforms =
      Just (t, SOAC.Input rest arr arr_t)
  | otherwise = Nothing
-- | A SOAC produced (or being built up) by fusion, together with the
-- bookkeeping needed to bind its results.
data FusedSOAC = FusedSOAC
  { -- | The SOAC itself.
    fsSOAC :: SOAC,
    -- | The output transforms of the SOAC.  Some fusion rules refuse to
    -- fire when these are non-empty.
    fsOutputTransform :: SOAC.ArrayTransforms,
    -- | The names of the SOAC's outputs.
    fsOutNames :: [VName]
  }
  deriving (Show)
-- | The inputs of the SOAC wrapped by a 'FusedSOAC'.
inputs :: FusedSOAC -> [SOAC.Input]
inputs fused = SOAC.inputs (fsSOAC fused)
-- | Replace the inputs of the SOAC wrapped by a 'FusedSOAC'.  The other
-- fields of the 'FusedSOAC' are left untouched.
setInputs :: [SOAC.Input] -> FusedSOAC -> FusedSOAC
setInputs new_inputs fused =
  fused {fsSOAC = SOAC.setInputs new_inputs (fsSOAC fused)}
-- | Fusion rule: run 'optimizeSOAC' on the producer and restart the
-- fusion rules with its result.
--
-- Any array transforms returned by 'optimizeSOAC' are added as initial
-- transforms to those consumer inputs that read a producer output.
-- The types of those inputs are then refreshed from the optimised
-- producer.
tryOptimizeSOAC ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeSOAC mode unfus_nms outVars soac ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let retyped = fixInputTypes (producerOutputs soac') (withTransforms ots)
  applyFusionRules mode unfus_nms outVars soac' retyped
  where
    -- The producer's outputs, typed according to the optimised SOAC.
    producerOutputs optimised = zipWith Ident outVars $ SOAC.typeOf optimised

    -- The consumer, with the transforms pushed onto relevant inputs.
    withTransforms ots = setInputs (map (addIfProducerOutput ots) (inputs ker)) ker

    -- Only inputs reading one of the producer's outputs are affected.
    addIfProducerOutput ots inp
      | SOAC.inputArray inp `elem` outVars = SOAC.addInitialTransforms ots inp
      | otherwise = inp
-- | Fusion rule: run 'optimizeKernel' on the consumer and, if that
-- succeeds, restart the fusion rules with the optimised consumer.
tryOptimizeKernel ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeKernel mode unfus_nms outVars soac ker =
  optimizeKernel (Just outVars) ker
    >>= applyFusionRules mode unfus_nms outVars soac
-- | Fusion rule: use 'exposeInputs' on the consumer.
--
-- If no transforms remain, fuse directly with 'fuseSOACwithKer'.
-- Otherwise, and only when @unfus_nms@ is empty, try to move the
-- transforms into the producer with 'pullOutputTransforms'.  This
-- fails unless all of them can be pulled; on success the fusion rules
-- are restarted with the retyped consumer.
tryExposeInputs ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryExposeInputs mode unfus_nms outVars soac ker = do
  (exposed, ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer mode unfus_nms outVars soac exposed
    else pullAndRetry exposed ots
  where
    pullAndRetry exposed ots = do
      guard $ unfus_nms == mempty
      (soac', leftover) <- pullOutputTransforms soac ots
      -- Every transform must have been absorbed by the producer.
      unless (SOAC.nullTransforms leftover) $
        fail "tryExposeInputs could not pull SOAC transforms"
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      applyFusionRules mode unfus_nms outVars soac' $
        fixInputTypes outIdents exposed
-- | For every input of the fused SOAC whose array is named by one of
-- the given identifiers, replace the input's type with that
-- identifier's type.  All other inputs are unchanged.
fixInputTypes :: [Ident] -> FusedSOAC -> FusedSOAC
fixInputTypes outIdents ker =
  ker {fsSOAC = SOAC.setInputs (map retype (SOAC.inputs soac)) soac}
  where
    soac = fsSOAC ker

    -- Keep the transforms and array name; only the type may change.
    retype inp@(SOAC.Input ts v _) =
      maybe inp (SOAC.Input ts v . identType) $
        find ((== v) . identName) outIdents
-- | Try each fusion rule in order and return the result of the first
-- one that succeeds.  If all of them fail, this fails too.
applyFusionRules ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
applyFusionRules mode unfus_nms outVars soac ker =
  attempt tryOptimizeSOAC
    <|> attempt tryOptimizeKernel
    <|> attempt fuseSOACwithKer
    <|> attempt tryExposeInputs
  where
    -- Every rule receives exactly the same arguments.
    attempt rule = rule mode unfus_nms outVars soac ker
-- | The kind of fusion being attempted.  'Vertical' also covers the
-- "diagonal" case, where some producer outputs must remain outputs of
-- the fused result.
data Mode
  = Horizontal
  | Vertical
-- | Attempt to fuse a producer SOAC into a consumer, using the scope
-- of the surrounding monad.  Returns 'Nothing' if no fusion rule
-- applies.
attemptFusion ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Mode ->
  -- | Producer outputs that must still be returned by the fused SOAC.
  Names ->
  -- | The outputs of the producer SOAC.
  [VName] ->
  -- | The producer.
  SOAC ->
  -- | The consumer.
  FusedSOAC ->
  m (Maybe FusedSOAC)
attemptFusion mode unfus_nms outVars soac ker =
  tryFusion (applyFusionRules mode unfus_nms outVars soac ker) =<< askScope
-- | Screma-into-screma fusion is allowed only if none of the producer's
-- non-map (scan and reduce) outputs appear among the consumer inputs
-- for which 'SOAC.isVarishInput' yields a name.  The map outputs are
-- not inspected.
scremaFusionOK :: ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (nonmap_outs, _map_outs) ker =
  not $ any (`elem` nonmap_outs) consumed
  where
    consumed = mapMaybe SOAC.isVarishInput (inputs ker)
-- | Map-into-scatter/hist fusion is allowed only if every producer
-- output appears among the consumer inputs for which
-- 'SOAC.isVarishInput' yields a name.
mapWriteFusionOK :: [VName] -> FusedSOAC -> Bool
mapWriteFusionOK outVars ker =
  null [v | v <- outVars, v `notElem` varish]
  where
    varish = mapMaybe SOAC.isVarishInput (inputs ker)
fuseSOACwithKer ::
Mode ->
Names ->
[VName] ->
SOAC ->
FusedSOAC ->
TryFusion FusedSOAC
fuseSOACwithKer :: Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer Mode
mode Names
unfus_set [VName]
outVars SOAC SOACS
soac_p FusedSOAC
ker = do
let soac_c :: SOAC SOACS
soac_c = FusedSOAC -> SOAC SOACS
fsSOAC FusedSOAC
ker
inp_p_arr :: [Input]
inp_p_arr = forall rep. SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_p
inp_c_arr :: [Input]
inp_c_arr = forall rep. SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_c
lam_p :: Lambda SOACS
lam_p = forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_p
lam_c :: Lambda SOACS
lam_c = forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_c
w :: SubExp
w = forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p
returned_outvars :: [VName]
returned_outvars = forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
success :: [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_outnms SOAC SOACS
res_soac = do
Lambda SOACS
uniq_lam <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda forall a b. (a -> b) -> a -> b
$ forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
res_soac
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
FusedSOAC
ker
{ fsSOAC :: SOAC SOACS
fsSOAC = Lambda SOACS
uniq_lam forall rep. Lambda rep -> SOAC rep -> SOAC rep
`SOAC.setLambda` SOAC SOACS
res_soac,
fsOutNames :: [VName]
fsOutNames = [VName]
res_outnms
}
forall (f :: * -> *). Alternative f => Bool -> f ()
guard forall a b. (a -> b) -> a -> b
$ forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p forall a. Eq a => a -> a -> Bool
== forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c
let ker_inputs :: [VName]
ker_inputs = forall a b. (a -> b) -> [a] -> [b]
map Input -> VName
SOAC.inputArray (FusedSOAC -> [Input]
inputs FusedSOAC
ker)
okInput :: VName -> Input -> Bool
okInput VName
v Input
inp = VName
v forall a. Eq a => a -> a -> Bool
/= Input -> VName
SOAC.inputArray Input
inp Bool -> Bool -> Bool
|| forall a. Maybe a -> Bool
isJust (Input -> Maybe VName
SOAC.isVarishInput Input
inp)
inputOrUnfus :: VName -> Bool
inputOrUnfus VName
v = forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Input -> Bool
okInput VName
v) (FusedSOAC -> [Input]
inputs FusedSOAC
ker) Bool -> Bool -> Bool
|| VName
v forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
ker_inputs
forall (f :: * -> *). Alternative f => Bool -> f ()
guard forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all VName -> Bool
inputOrUnfus [VName]
outVars
[(VName, Ident)]
outPairs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall u. TypeBase Shape u -> TypeBase Shape u
rowType forall a b. (a -> b) -> a -> b
$ forall rep. SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p) forall a b. (a -> b) -> a -> b
$ \(VName
outVar, TypeBase Shape NoUniqueness
t) -> do
VName
outVar' <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
outVar forall a. [a] -> [a] -> [a]
++ [Char]
"_elem"
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
outVar, VName -> TypeBase Shape NoUniqueness -> Ident
Ident VName
outVar' TypeBase Shape NoUniqueness
t)
let mapLikeFusionCheck :: ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck =
let (Lambda SOACS
res_lam, [Input]
new_inp) = forall rep.
Buildable rep =>
Names
-> Lambda rep
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [Input]
-> (Lambda rep, [Input])
fuseMaps Names
unfus_set Lambda SOACS
lam_p [Input]
inp_p_arr [(VName, Ident)]
outPairs Lambda SOACS
lam_c [Input]
inp_c_arr
([VName]
extra_nms, [TypeBase Shape NoUniqueness]
extra_rtps) =
forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
unfus_set) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) forall a b. (a -> b) -> a -> b
$
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map (forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray Int
1) forall a b. (a -> b) -> a -> b
$
forall rep. SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p
res_lam' :: Lambda SOACS
res_lam' = Lambda SOACS
res_lam {lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
res_lam forall a. [a] -> [a] -> [a]
++ [TypeBase Shape NoUniqueness]
extra_rtps}
in ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp)
case (SOAC SOACS
soac_c, SOAC SOACS
soac_p, Mode
mode) of
(SOAC SOACS, SOAC SOACS, Mode)
_ | forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p forall a. Eq a => a -> a -> Bool
/= forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c -> forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC widths must match."
(SOAC SOACS
_, SOAC SOACS
_, Mode
Horizontal)
| Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker) ->
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Horizontal fusion is invalid in the presence of output transforms."
(SOAC SOACS
_, SOAC SOACS
_, Mode
Vertical)
| Names
unfus_set forall a. Eq a => a -> a -> Bool
/= forall a. Monoid a => a
mempty,
Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker) ->
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot perform diagonal fusion in the presence of output transforms."
( SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_c [Reduce SOACS]
reds_c Lambda SOACS
_) [Input]
_,
SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_p [Reduce SOACS]
reds_p Lambda SOACS
_) [Input]
_,
Mode
_
)
| ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (forall a. Int -> [a] -> ([a], [a])
splitAt (forall rep. [Scan rep] -> Int
Futhark.scanResults [Scan SOACS]
scans_p forall a. Num a => a -> a -> a
+ forall rep. [Reduce rep] -> Int
Futhark.redResults [Reduce SOACS]
reds_p) [VName]
outVars) FusedSOAC
ker -> do
let red_nes_p :: [SubExp]
red_nes_p = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_p
red_nes_c :: [SubExp]
red_nes_c = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_c
scan_nes_p :: [SubExp]
scan_nes_p = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_p
scan_nes_c :: [SubExp]
scan_nes_c = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_c
(Lambda SOACS
res_lam', [Input]
new_inp) =
forall rep.
Buildable rep =>
Names
-> [VName]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda rep, [Input])
fuseRedomap
Names
unfus_set
[VName]
outVars
Lambda SOACS
lam_p
[SubExp]
scan_nes_p
[SubExp]
red_nes_p
[Input]
inp_p_arr
[(VName, Ident)]
outPairs
Lambda SOACS
lam_c
[SubExp]
scan_nes_c
[SubExp]
red_nes_c
[Input]
inp_c_arr
([VName]
soac_p_scanout, [VName]
soac_p_redout, [VName]
_soac_p_mapout) =
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_p) (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_p) [VName]
outVars
([VName]
soac_c_scanout, [VName]
soac_c_redout, [VName]
soac_c_mapout) =
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_c) (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_c) forall a b. (a -> b) -> a -> b
$ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker
unfus_arrs :: [VName]
unfus_arrs = [VName]
returned_outvars forall a. Eq a => [a] -> [a] -> [a]
\\ ([VName]
soac_p_scanout forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout)
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success
( [VName]
soac_p_scanout
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_scanout
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_redout
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_mapout
forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs
)
forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma
SubExp
w
(forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm ([Scan SOACS]
scans_p forall a. [a] -> [a] -> [a]
++ [Scan SOACS]
scans_c) ([Reduce SOACS]
reds_p forall a. [a] -> [a] -> [a]
++ [Reduce SOACS]
reds_c) Lambda SOACS
res_lam')
[Input]
new_inp
( SOAC.Scatter SubExp
_len Lambda SOACS
_lam [Input]
_ivs [(Shape, Int, VName)]
dests,
SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_,
Mode
_
)
| forall a. Maybe a -> Bool
isJust forall a b. (a -> b) -> a -> b
$ forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
[VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
res_lam' [Input]
new_inp [(Shape, Int, VName)]
dests
( SOAC.Hist SubExp
_ [HistOp SOACS]
ops Lambda SOACS
_ [Input]
_,
SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_,
Mode
_
)
| forall a. Maybe a -> Bool
isJust forall a b. (a -> b) -> a -> b
$ forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
[VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
res_lam' [Input]
new_inp
( SOAC.Hist SubExp
_ [HistOp SOACS]
ops_c Lambda SOACS
_ [Input]
_,
SOAC.Hist SubExp
_ [HistOp SOACS]
ops_p Lambda SOACS
_ [Input]
_,
Mode
Horizontal
) -> do
let p_num_buckets :: Int
p_num_buckets = forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_p
c_num_buckets :: Int
c_num_buckets = forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_c
(Body SOACS
body_p, Body SOACS
body_c) = (forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
body' :: Body SOACS
body' =
Body
{ bodyDec :: BodyDec SOACS
bodyDec = forall rep. Body rep -> BodyDec rep
bodyDec Body SOACS
body_p,
bodyStms :: Stms SOACS
bodyStms = forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_p forall a. Semigroup a => a -> a -> a
<> forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_c,
bodyResult :: Result
bodyResult =
forall a. Int -> [a] -> [a]
take Int
c_num_buckets (forall rep. Body rep -> Result
bodyResult Body SOACS
body_c)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
take Int
p_num_buckets (forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (forall rep. Body rep -> Result
bodyResult Body SOACS
body_c)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
}
lam' :: Lambda SOACS
lam' =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c forall a. [a] -> [a] -> [a]
++ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType =
forall a. Int -> a -> [a]
replicate (Int
c_num_buckets forall a. Num a => a -> a -> a
+ Int
p_num_buckets) (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c)
forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
}
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w ([HistOp SOACS]
ops_c forall a. Semigroup a => a -> a -> a
<> [HistOp SOACS]
ops_p) Lambda SOACS
lam' ([Input]
inp_c_arr forall a. Semigroup a => a -> a -> a
<> [Input]
inp_p_arr)
( SOAC.Scatter SubExp
_len_c Lambda SOACS
_lam_c [Input]
ivs_c [(Shape, Int, VName)]
as_c,
SOAC.Scatter
SubExp
_len_p
Lambda SOACS
_lam_p
[Input]
ivs_p
[(Shape, Int, VName)]
as_p,
Mode
Horizontal
) -> do
let zipW :: [(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, array)]
as_xs [a]
xs [(Shape, Int, array)]
as_ys [a]
ys = [a]
xs_indices forall a. [a] -> [a] -> [a]
++ [a]
ys_indices forall a. [a] -> [a] -> [a]
++ [a]
xs_vals forall a. [a] -> [a] -> [a]
++ [a]
ys_vals
where
([a]
xs_indices, [a]
xs_vals) = forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_xs [a]
xs
([a]
ys_indices, [a]
ys_vals) = forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_ys [a]
ys
let (Body SOACS
body_p, Body SOACS
body_c) = (forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
let body' :: Body SOACS
body' =
Body
{ bodyDec :: BodyDec SOACS
bodyDec = forall rep. Body rep -> BodyDec rep
bodyDec Body SOACS
body_p,
bodyStms :: Stms SOACS
bodyStms = forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_p forall a. Semigroup a => a -> a -> a
<> forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_c,
bodyResult :: Result
bodyResult = forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (forall rep. Body rep -> Result
bodyResult Body SOACS
body_c) [(Shape, Int, VName)]
as_p (forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
}
let lam' :: Lambda SOACS
lam' =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c forall a. [a] -> [a] -> [a]
++ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c) [(Shape, Int, VName)]
as_p (forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
}
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
lam' ([Input]
ivs_c forall a. [a] -> [a] -> [a]
++ [Input]
ivs_p) ([(Shape, Int, VName)]
as_c forall a. [a] -> [a] -> [a]
++ [(Shape, Int, VName)]
as_p)
(SOAC.Scatter {}, SOAC SOACS
_, Mode
_) ->
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a scatter with anything else than a scatter or a map"
(SOAC SOACS
_, SOAC.Scatter {}, Mode
_) ->
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a scatter with anything else than a scatter or a map"
(SOAC.Stream {}, SOAC.Stream {}, Mode
_) -> do
([VName]
res_nms, SOAC SOACS
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC SOACS
-> SOAC SOACS
-> TryFusion ([VName], SOAC SOACS)
fuseStreamHelper (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC SOACS
soac_c SOAC SOACS
soac_p
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_nms SOAC SOACS
res_stream
(SOAC.Stream {}, SOAC SOACS
_, Mode
_) -> do
(SOAC SOACS
soac_p', [Ident]
newacc_ids) <- forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
Mode
mode
([VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
(forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
SOAC SOACS
soac_p'
FusedSOAC
ker
(SOAC SOACS
_, SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_, Mode
_) | Just ([Scan SOACS], Lambda SOACS)
_ <- forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm SOACS
form -> do
(SOAC SOACS
soac_p', [Ident]
newacc_ids) <- forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
if SOAC SOACS
soac_p' forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_p
then
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
Mode
mode
([VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
(forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
SOAC SOACS
soac_p'
FusedSOAC
ker
else forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."
(SOAC SOACS
_, SOAC.Stream {}, Mode
_) -> do
(SOAC SOACS
soac_c', [Ident]
newacc_ids) <- forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_c
if SOAC SOACS
soac_c' forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_c
then
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
Mode
mode
([VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
[VName]
outVars
SOAC SOACS
soac_p
forall a b. (a -> b) -> a -> b
$ FusedSOAC
ker {fsSOAC :: SOAC SOACS
fsSOAC = SOAC SOACS
soac_c', fsOutNames :: [VName]
fsOutNames = forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids forall a. [a] -> [a] -> [a]
++ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker}
else forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."
(SOAC SOACS, SOAC SOACS, Mode)
_ -> forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse"
-- | Fuse two streams into a single stream.
--
-- This mirrors redomap-o-redomap composition via 'fuseRedomap', except that
-- the leading chunk-size parameter of each stream lambda is removed first.
-- The first stream's chunk parameter is renamed to the second stream's, and
-- the second stream's chunk parameter is put back at the front of the fused
-- lambda.  At the call site in this module the first stream is the kernel's
-- SOAC (@soac_c@) and the second is the SOAC being fused into it (@soac_p@).
--
-- Returns the output names of the fused stream and the stream itself.  The
-- names are ordered as follows:
--
-- * the first @length nes1@ entries of @outVars@,
-- * then @out_kernms@,
-- * then the remaining @outVars@ that are in @unfus_set@.
--
-- Fails if either argument is not a stream.  It also fails, instead of
-- crashing, if a stream lambda has no parameters.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ lam1 nes1 inp1_arr)
    | chunk1 : params1 <- lambdaParams lam1,
      chunk2 : _ <- lambdaParams lam2 = do
        let -- Make the first lambda refer to the second one's chunk size.
            lam20 =
              substituteNames (M.fromList [(paramName chunk2, paramName chunk1)]) lam2
            -- Strip the chunk parameters before redomap-style fusion.
            lam1' = lam1 {lambdaParams = params1}
            lam2' = lam20 {lambdaParams = drop 1 $ lambdaParams lam20}
            (res_lam', new_inp) =
              fuseRedomap
                unfus_set
                outVars
                lam1'
                []
                nes1
                inp1_arr
                outPairs
                lam2'
                []
                nes2
                inp2_arr
            -- Reinstate a single chunk parameter on the fused lambda.
            res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
            unfus_accs = take (length nes1) outVars
            unfus_arrs =
              filter (\v -> v `nameIn` unfus_set && v `notElem` unfus_accs) outVars
        pure
          ( unfus_accs ++ out_kernms ++ unfus_arrs,
            SOAC.Stream w2 res_lam'' (nes1 ++ nes2) new_inp
          )
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"
-- | Run 'optimizeSOAC' on the SOAC of a fused kernel.  Both the kernel's SOAC
-- and its output transforms are replaced with the optimised results.  Fails
-- whenever 'optimizeSOAC' fails.
optimizeKernel :: Maybe [VName] -> FusedSOAC -> TryFusion FusedSOAC
optimizeKernel inp ker =
  install <$> optimizeSOAC inp (fsSOAC ker) (fsOutputTransform ker)
  where
    install (soac, resTrans) = ker {fsSOAC = soac, fsOutputTransform = resTrans}
-- | Try each rewrite in 'optimizations' in order.  The SOAC and its output
-- transforms are threaded through every rewrite that succeeds; a rewrite that
-- fails is skipped (its effects are discarded via '<|>').  Fails with
-- "No optimisation applied" unless at least one rewrite succeeded.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> pure (soac', os')
  where
    -- Each rewrite must be given the transforms produced by the previous
    -- rewrites (@os'@), not the initial @os@.  Otherwise the transforms of
    -- earlier successful rewrites would be silently dropped while their
    -- rewritten SOAC is kept.
    comb (changed, soac', os') f =
      attempt <|> pure (changed, soac', os')
      where
        attempt = do
          (soac'', os'') <- f inp soac' os'
          pure (True, soac'', os'')
-- | A rewrite of a SOAC together with the array transforms applied to its
-- output.  The first argument is the optional input-name list passed along by
-- 'optimizeSOAC'.  A rewrite signals that it does not apply by failing.
type Optimization =
  Maybe [VName] -> SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
-- | The rewrites tried, in this order, by 'optimizeSOAC'.  Currently only
-- 'iswim'.
optimizations :: [Optimization]
optimizations =
  [ iswim
  ]
-- | Interchange a scan with an inner map (ISWIM).
--
-- Applies to a 'SOAC.Screma' that is a single scan whose operator is accepted
-- by 'rwimPossible', and whose neutral elements are all variables.  The scan
-- is rewritten into a map whose body is a scan:
--
-- * the neutral elements become extra map inputs, in front of the scanned
--   arrays, which are transposed with @SOAC.transposeInput 0 1@;
-- * a rearrangement swapping the two outermost result dimensions
--   (@1 : 0 : [2 ..]@) is added to the output transforms.
--
-- Fails with "ISWIM does not apply." otherwise.
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w form arrs) ots
  | Just [Futhark.Scan scan_fun nes] <- Futhark.isScanSOAC form,
    Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun,
    Just nes_names <- mapM subExpVar nes = do
      let scan_ts = lambdaReturnType scan_fun
          map_nes = map SOAC.identInput $ zipWith Ident nes_names scan_ts
          map_arrs' = map_nes ++ map (SOAC.transposeInput 0 1) arrs
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          -- Accumulators lose their outer dimension; elements get size @w@.
          map_params =
            map removeParamOuterDim scan_acc_params
              ++ map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (`setOuterSize` w) scan_ts
          -- The first parameters stand for the neutral elements, the rest for
          -- the arrays scanned by the inner scan.
          (ne_param_names, arrs') =
            splitAt (length map_nes) $ map paramName map_params
          scan_fun' =
            Lambda (lambdaParams map_fun) (lambdaBody map_fun) (lambdaReturnType map_fun)

      scan_form <- scanSOAC [Futhark.Scan scan_fun' (map Var ne_param_names)]

      let inner_scan =
            Let (setPatOuterDimTo w map_pat) (defAux ()) $
              Op $
                Futhark.Screma w arrs' scan_form
          map_body = mkBody (oneStm inner_scan) $ varsRes $ patNames map_pat
          map_fun' = Lambda map_params map_body map_rettype
          -- Swap the two outermost dimensions of the result back.
          perm = case scan_ts of
            [] -> []
            t : _ -> 1 : 0 : [2 .. arrayRank t]

      pure
        ( SOAC.Screma map_w (ScremaForm [] [] map_fun') map_arrs',
          ots SOAC.|> SOAC.Rearrange map_cs perm
        )
iswim _ _ _ =
  fail "ISWIM does not apply."
-- | Drop the outermost dimension from a lambda parameter's type (via
-- 'rowType').
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim param =
  param {paramDec = rowType $ paramType param}
-- | Replace the outermost dimension of a lambda parameter's type with @w@.
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w param =
  param {paramDec = setOuterSize (paramType param) w}
-- | Replace the outermost dimension of every type in a pattern with @w@.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w pat = fmap (\t -> setOuterSize t w) pat
-- | Extract the transforms shared by all inputs whose array is listed in
-- @interesting@.  Each input is tagged with whether its array is interesting,
-- and the work is delegated to commonTransforms'.  Inputs on other arrays
-- are passed through untouched.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting inps =
  commonTransforms' $ map tag inps
  where
    tag inp = (SOAC.inputArray inp `elem` interesting, inp)
-- | Worker for 'commonTransforms'.
--
-- While every input tagged 'True' yields the same transform from
-- 'inputToOutput', that transform is split off all of them and collected.
-- Inputs tagged 'False' are left alone.  Returns the collected transforms
-- (first-extracted at the front) and the inputs with those transforms
-- removed, in their original order.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' tagged =
  case foldM step (Nothing, []) tagged of
    Just (Just common, rev_stripped) ->
      first (common SOAC.<|) $ commonTransforms' $ reverse rev_stripped
    _ -> (SOAC.noTransforms, map snd tagged)
  where
    -- The accumulator holds the transform agreed on so far, and the inputs
    -- seen so far in reverse order.
    step (agreed, acc) (False, inp) = Just (agreed, (False, inp) : acc)
    step (agreed, acc) (True, inp) = do
      (ot, inp') <- inputToOutput inp
      case agreed of
        Just ot' | ot' /= ot -> Nothing
        _ -> Just (Just ot, (True, inp') : acc)
-- | One more than the smaller of two quantities:
--
-- * the number of nesting levels;
-- * the smallest rank among the return types of the first nesting level
--   (or of the lambda when there are no levels), or 0 if there are none.
--
-- 'pullRearrange' and 'pushRearrange' compare 'rearrangeReach' against this.
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  1 + min resDims (length levels)
  where
    outerTypes = case levels of
      nest : _ -> MapNest.nestingReturnType nest
      [] -> lambdaReturnType lam
    resDims
      | null outerTypes = 0
      | otherwise = minimum $ map arrayRank outerTypes
-- | Pull the first output transform of a map nest inside the nest.  The
-- transform must be a 'SOAC.Rearrange' whose reach is at most 'mapDepth'.
--
-- The rearrangement is removed from the output transforms and is instead:
--
-- * added to every input, truncated or padded with the identity to that
--   input's rank;
-- * applied to the nest's return types.
--
-- Fails if the SOAC is not a map nest, the first transform is not a
-- rearrangement, or its reach is too deep.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- pure $ SOAC.viewf ots
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot pull transpose"
  let -- Expand perm to cover the full dimensionality of the input.
      extend inp = take r perm ++ [length perm .. r - 1]
        where
          r = SOAC.inputRank inp
      permuteInput inp = SOAC.addTransform (SOAC.Rearrange cs $ extend inp) inp
      inputs' = map permuteInput $ MapNest.inputs nest
  soac' <-
    MapNest.toSOAC $
      inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
  pure (soac', ots')
-- | Push a rearrangement found on an input out of a map nest.
--
-- 'fixupInputs' finds the permutation (looking only at inputs on arrays in
-- @inpIds@) and applies it to every input.  If its reach is at most
-- 'mapDepth', the nest's return types are permuted and the inverse
-- permutation is prepended to the output transforms.  Fails otherwise.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot push transpose"
  soac' <-
    MapNest.toSOAC $
      inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
  let undo = SOAC.Rearrange mempty $ rearrangeInverse perm
  pure (soac', undo SOAC.<| ots)
-- | Permute the return types of a map nest's nesting levels by @perm@.
--
-- The permutation is truncated to each type's rank.  It may reach deeper
-- than a type's rank, but is then assumed to be the identity beyond it;
-- callers are expected to guarantee this.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    permute t = rearrangeType (take (arrayRank t) perm) t
    permuted = map permute $ MapNest.typeOf nest
    -- The k-th nesting (0-based, in list order) gets the permuted types with
    -- k+1 outer dimensions stripped.
    nestings' =
      zipWith
        (\nesting ts -> nesting {MapNest.nestingReturnType = ts})
        nestings
        (drop 1 $ iterate (map rowType) permuted)
-- | Select a permutation from the inputs and apply it to all of them.
--
-- The permutation comes from the first input whose array is in @inpIds@ and
-- whose final transform (per 'SOAC.viewl') is a 'SOAC.Rearrange'.  It is
-- added, truncated to each input's rank, to every input.  Returns the
-- permutation and the updated inputs.  Returns 'Nothing' if there is no such
-- input, or if some input's rank is below the permutation's
-- 'rearrangeReach'.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps = do
  perm <- listToMaybe $ mapMaybe inputRearrange $ filter exposable inps
  let reach = rearrangeReach perm
      fixup inp
        | r >= reach =
            Just $ SOAC.addTransform (SOAC.Rearrange mempty $ take r perm) inp
        | otherwise = Nothing
        where
          r = SOAC.inputRank inp
  inps' <- traverse fixup inps
  pure (perm, inps')
  where
    exposable inp = SOAC.inputArray inp `elem` inpIds

    inputRearrange (SOAC.Input ts _ _)
      | _ SOAC.:> SOAC.Rearrange _ perm <- SOAC.viewl ts = Just perm
    inputRearrange _ = Nothing
-- | Try to pull a leading 'SOAC.Reshape' output transform into a @map@ SOAC.
--
-- Applies only when all of the following hold:
--
-- * the SOAC is a 'SOAC.Screma' recognised by 'Futhark.isMapSOAC';
-- * the first transform in @ots@ is a 'SOAC.Reshape';
-- * the map lambda returns only primitive values.
--
-- When it applies, the reshape is moved onto every input. Rank-1 inputs get
-- a full 'SOAC.Reshape'; other inputs get a 'SOAC.ReshapeOuter'. The
-- original map lambda then becomes the innermost level of a nest of maps,
-- one map per dimension of the target shape.
--
-- Returns the rewritten SOAC together with the remaining output transforms.
-- Fails with "Cannot pull reshape" when the pattern does not apply.
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape (SOAC.Screma _ form inps) ots
  | Just maplam <- Futhark.isMapSOAC form,
    SOAC.Reshape cs k shape SOAC.:< ots' <- SOAC.viewf ots,
    all primType $ lambdaReturnType maplam = do
      let dims = shapeDims shape
          -- The innermost map iterates over the last dimension of the
          -- target shape (0 if the shape has no dimensions).
          mapw' = case reverse dims of
            [] -> intConst Int64 0
            d : _ -> d
          trInput inp
            | arrayRank (SOAC.inputType inp) == 1 =
                SOAC.addTransform (SOAC.Reshape cs k shape) inp
            | otherwise =
                SOAC.addTransform (SOAC.ReshapeOuter cs k shape) inp
          inputs' = map trInput inps
          inputTypes = map SOAC.inputType inputs'

      -- Wrap an inner SOAC constructor in one more map of width @w@.
      -- @outershape@ lists the dimensions already produced by the inner nest.
      -- Those dimensions are added to the lambda's (primitive) return types.
      -- The same count is used to peel the right number of outer dimensions
      -- off each input type to get the parameter type.
      let outersoac ::
            ([SOAC.Input] -> SOAC) ->
            (SubExp, [SubExp]) ->
            TryFusion ([SOAC.Input] -> SOAC)
          outersoac inner (w, outershape) = do
            let addDims t = arrayOf t (Shape outershape) NoUniqueness
                retTypes = map addDims $ lambdaReturnType maplam

            ps <- forM inputTypes $ \inpt ->
              newParam "pullReshape_param" $
                stripArray (length shape - length outershape) inpt

            inner_body <-
              runBodyBuilder $
                eBody [SOAC.toExp $ inner $ map (SOAC.identInput . paramIdent) ps]
            let inner_fun =
                  Lambda
                    { lambdaParams = ps,
                      lambdaReturnType = retTypes,
                      lambdaBody = inner_body
                    }
            pure $ SOAC.Screma w (Futhark.mapSOAC inner_fun)

      -- Pair each outer width (innermost-but-one first) with the dimensions
      -- nested beneath it. For dims [a, b, c] this gives
      -- [(b, [c]), (a, [b, c])].
      op' <-
        foldM outersoac (SOAC.Screma mapw' (Futhark.mapSOAC maplam)) $
          zip (drop 1 $ reverse dims) $
            drop 1 . reverse . drop 1 . tails $
              dims
      pure (op' inputs', ots')
pullReshape _ _ = fail "Cannot pull reshape"
-- | Try to rewrite a fused SOAC so that no input whose array is one of
-- @inpIds@ still carries array transforms.
--
-- Three candidate kernels are tried in order, and the first that works wins:
--
-- 1. the kernel rewritten by 'pushRearrange' (given @inpIds@ and the current
--    output transforms);
-- 2. the kernel rewritten by 'pullRearrange'; this candidate is rejected
--    unless it leaves no output transforms;
-- 3. the kernel unchanged.
--
-- For each candidate, 'commonTransforms' splits transforms off the inputs.
-- The candidate succeeds only if every remaining input is "exposed": it has
-- no transforms, or its array is not in @inpIds@. On success the result is
-- the kernel with the new inputs, paired with the transforms that
-- 'commonTransforms' split off. If no candidate succeeds, the last failure
-- is "Cannot expose".
exposeInputs ::
  [VName] ->
  FusedSOAC ->
  TryFusion (FusedSOAC, SOAC.ArrayTransforms)
exposeInputs inpIds ker =
  (exposeInputs' =<< pushRearrange')
    <|> (exposeInputs' =<< pullRearrange')
    <|> exposeInputs' ker
  where
    ot :: SOAC.ArrayTransforms
    ot = fsOutputTransform ker

    pushRearrange' :: TryFusion FusedSOAC
    pushRearrange' = do
      (soac', ot') <- pushRearrange inpIds (fsSOAC ker) ot
      pure
        ker
          { fsSOAC = soac',
            fsOutputTransform = ot'
          }

    pullRearrange' :: TryFusion FusedSOAC
    pullRearrange' = do
      (soac', ot') <- pullRearrange (fsSOAC ker) ot
      -- Only accept this candidate if every output transform was absorbed.
      unless (SOAC.nullTransforms ot') $
        fail "pullRearrange was not enough"
      pure
        ker
          { fsSOAC = soac',
            fsOutputTransform = SOAC.noTransforms
          }

    exposeInputs' :: FusedSOAC -> TryFusion (FusedSOAC, SOAC.ArrayTransforms)
    exposeInputs' ker' =
      case commonTransforms inpIds $ inputs ker' of
        (ot', inps')
          | all exposed inps' ->
              pure (ker' {fsSOAC = inps' `SOAC.setInputs` fsSOAC ker'}, ot')
        _ -> fail "Cannot expose"

    -- An input is exposed if it carries no transforms, or if it does not
    -- read one of the arrays in @inpIds@.
    exposed :: SOAC.Input -> Bool
    exposed (SOAC.Input ts _ _)
      | SOAC.nullTransforms ts = True
    exposed inp = SOAC.inputArray inp `notElem` inpIds
-- | The pullers tried by 'pullOutputTransforms', in this order. Each one
-- attempts to move output transforms of a SOAC into the SOAC itself. It
-- returns the rewritten SOAC and whatever transforms remain, or fails.
outputTransformPullers :: [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers = [pullRearrange, pullReshape]
-- | Pull as many output transforms as possible into a SOAC.
--
-- The pullers in 'outputTransformPullers' are tried in order until one
-- succeeds. If that puller leaves no transforms, the result carries
-- 'SOAC.noTransforms'. Otherwise pulling continues recursively on what
-- remains; if the recursion fails, the partial result is kept. Fails with
-- "Cannot pull anything" only when every puller fails on the initial input.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms = attempt outputTransformPullers
  where
    attempt ::
      [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)] ->
      SOAC ->
      SOAC.ArrayTransforms ->
      TryFusion (SOAC, SOAC.ArrayTransforms)
    attempt [] _ _ = fail "Cannot pull anything"
    attempt (p : ps) soac ots =
      -- The parentheses make explicit that the fallback to the remaining
      -- pullers only applies when this whole step fails. Without them the
      -- grouping depends on the layout rule implicitly closing the 'do'.
      ( do
          (soac', ots') <- p soac ots
          if SOAC.nullTransforms ots'
            then pure (soac', SOAC.noTransforms)
            else pullOutputTransforms soac' ots' <|> pure (soac', ots')
      )
        <|> attempt ps soac ots