{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion.TryFusion
( FusedSOAC (..),
Mode (..),
attemptFusion,
)
where
import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.MapNest qualified as MapNest
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)
-- | The monad in which a fusion attempt runs.  It can consult the current
-- type 'Scope', draw fresh names, and fail (through 'empty', 'guard' or
-- 'fail'), in which case the next alternative of a '<|>' is tried.  The
-- underlying stack is a reader over a state transformer over 'Maybe'.
newtype TryFusion a
  = TryFusion (ReaderT (Scope SOACS) (StateT VNameSource Maybe) a)
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Run a fusion attempt in the given scope.
--
-- On success the result is wrapped in 'Just' and the name source is
-- advanced to the state the attempt left it in.  On failure the result is
-- 'Nothing' and the name source is put back exactly as it was before the
-- attempt, so names drawn by a failed attempt are discarded.
tryFusion ::
  (MonadFreshNames m) =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion attempt) scope =
  modifyNameSource $ \src ->
    maybe (Nothing, src) (first Just) $
      runStateT (runReaderT attempt scope) src
-- | Embed a 'Maybe' into 'TryFusion'.  'Just' succeeds with its contents;
-- 'Nothing' makes the attempt fail (with the message @"Nothing"@).
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") pure
-- | Shorthand for the "Futhark.Analysis.HORep.SOAC" representation
-- instantiated at 'SOACS'.
type SOAC = SOAC.SOAC SOACS
-- | Shorthand for a map nest ("Futhark.Analysis.HORep.MapNest")
-- instantiated at 'SOACS'.
type MapNest = MapNest.MapNest SOACS
-- | Split the first array transformation (as seen by 'SOAC.viewf') off an
-- input.  Returns that transformation together with the input that has the
-- remaining transformations, or 'Nothing' if the input carries none.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input transforms arr arr_t) =
  case SOAC.viewf transforms of
    SOAC.EmptyF -> Nothing
    t SOAC.:< rest -> Just (t, SOAC.Input rest arr arr_t)
-- | A SOAC taking part in fusion (the consumer side in 'attemptFusion'),
-- together with the information needed to bind its results.
data FusedSOAC = FusedSOAC
  { -- | The SOAC itself.
    fsSOAC :: SOAC,
    -- | Array transformations associated with the results of 'fsSOAC'.
    fsOutputTransform :: SOAC.ArrayTransforms,
    -- | The names given to the results of 'fsSOAC'.
    fsOutNames :: [VName]
  }
  deriving (Show)
-- | The inputs of the SOAC wrapped in a 'FusedSOAC'.
inputs :: FusedSOAC -> [SOAC.Input]
inputs ker = SOAC.inputs (fsSOAC ker)
-- | Replace the inputs of the SOAC wrapped in a 'FusedSOAC'.  All other
-- fields are left unchanged.
setInputs :: [SOAC.Input] -> FusedSOAC -> FusedSOAC
setInputs inps ker =
  let soac = fsSOAC ker
   in ker {fsSOAC = SOAC.setInputs inps soac}
-- | Fusion rule: rewrite the producer @soac@ with 'optimizeSOAC', which
-- yields a new producer and a set of array transformations.
--
-- Those transformations are prepended to every consumer input that reads
-- one of the producer's outputs (@outVars@).  The types of those inputs are
-- then refreshed from the new producer's result types, and all fusion
-- rules are retried on the rewritten pair.
tryOptimizeSOAC ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeSOAC mode unfus_nms outVars soac ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let prependIfProduced inp
        | SOAC.inputArray inp `elem` outVars = SOAC.addInitialTransforms ots inp
        | otherwise = inp
      out_idents = zipWith Ident outVars (SOAC.typeOf soac')
      ker' =
        fixInputTypes out_idents $
          setInputs (map prependIfProduced (inputs ker)) ker
  applyFusionRules mode unfus_nms outVars soac' ker'
-- | Fusion rule: rewrite the consumer with 'optimizeKernel', passing it the
-- producer's output names, and then retry all fusion rules on the result.
tryOptimizeKernel ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeKernel mode unfus_nms outVars soac ker =
  optimizeKernel (Just outVars) ker
    >>= applyFusionRules mode unfus_nms outVars soac
-- | Fusion rule: apply 'exposeInputs' to the consumer, which yields a
-- rewritten consumer and a set of array transformations.
--
-- * If that set is empty, attempt direct fusion with 'fuseSOACwithKer'.
-- * Otherwise the rule applies only when @unfus_nms@ is empty.  The
--   transformations are handed to 'pullOutputTransforms' together with the
--   producer (presumably to be absorbed into it).  If none are left over,
--   the consumer's input types are refreshed from the new producer and all
--   rules are retried.  If some are left over, the rule fails.
tryExposeInputs ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryExposeInputs mode unfus_nms outVars soac ker = do
  (exposed, ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer mode unfus_nms outVars soac exposed
    else do
      guard $ unfus_nms == mempty
      (soac', leftover) <- pullOutputTransforms soac ots
      unless (SOAC.nullTransforms leftover) $
        fail "tryExposeInputs could not pull SOAC transforms"
      let retyped = fixInputTypes (zipWith Ident outVars (SOAC.typeOf soac')) exposed
      applyFusionRules mode unfus_nms outVars soac' retyped
-- | For every input of the wrapped SOAC whose array is named by one of the
-- given identifiers, replace the input's recorded type with that
-- identifier's type.  The input's array transformations are kept.  All
-- other inputs are left untouched.
fixInputTypes :: [Ident] -> FusedSOAC -> FusedSOAC
fixInputTypes outIdents ker =
  ker {fsSOAC = SOAC.setInputs (map retype (SOAC.inputs soac)) soac}
  where
    soac = fsSOAC ker
    retype inp@(SOAC.Input ts v _) =
      case find ((== v) . identName) outIdents of
        Just ident -> SOAC.Input ts v (identType ident)
        Nothing -> inp
-- | Try each fusion rule in order and return the result of the first one
-- that succeeds: producer optimisation, consumer optimisation, direct
-- fusion, and finally exposing consumer inputs.  Fails if every rule
-- fails.  Because 'TryFusion' sits on top of 'Maybe', the name-source
-- changes made by a failed rule are discarded before the next rule runs.
applyFusionRules ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
applyFusionRules mode unfus_nms outVars soac ker =
  foldr ((<|>) . attempt) empty rules
  where
    attempt rule = rule mode unfus_nms outVars soac ker
    rules =
      [ tryOptimizeSOAC,
        tryOptimizeKernel,
        fuseSOACwithKer,
        tryExposeInputs
      ]
-- | The kind of fusion being attempted.  It restricts which rules
-- 'fuseSOACwithKer' may apply.  For example, 'Horizontal' fusion is
-- rejected when the consumer has output transforms.
data Mode = Horizontal | Vertical
-- | Try to fuse the producer @soac@, whose results are bound to @outVars@,
-- with the consumer @ker@, using the current scope.
--
-- @unfus_nms@ appears to name producer outputs that must remain available
-- after fusion; TODO confirm against callers.  Returns 'Nothing' if no
-- fusion rule applies, in which case the name source is left unchanged.
attemptFusion ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  m (Maybe FusedSOAC)
attemptFusion mode unfus_nms outVars soac ker =
  askScope >>= tryFusion (applyFusionRules mode unfus_nms outVars soac ker)
-- | Check whether a Screma producer may be fused with the consumer.
--
-- The first component of the pair holds the producer's non-map (scan and
-- reduce) outputs; the second component is ignored.  No consumer input for
-- which 'SOAC.isVarishInput' yields a name may read one of those non-map
-- outputs.
scremaFusionOK :: ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (nonmap_outs, _map_outs) ker =
  and [v `notElem` nonmap_outs | Just v <- map SOAC.isVarishInput (inputs ker)]
-- | Check that every producer output in @outVars@ is among the names
-- obtained via 'SOAC.isVarishInput' from the consumer's inputs.
mapWriteFusionOK :: [VName] -> FusedSOAC -> Bool
mapWriteFusionOK outVars ker = not (any (`notElem` varish_inputs) outVars)
  where
    varish_inputs = mapMaybe SOAC.isVarishInput (inputs ker)
fuseSOACwithKer ::
Mode ->
Names ->
[VName] ->
SOAC ->
FusedSOAC ->
TryFusion FusedSOAC
fuseSOACwithKer :: Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer Mode
mode Names
unfus_set [VName]
outVars SOAC SOACS
soac_p FusedSOAC
ker = do
let soac_c :: SOAC SOACS
soac_c = FusedSOAC -> SOAC SOACS
fsSOAC FusedSOAC
ker
inp_p_arr :: [Input]
inp_p_arr = SOAC SOACS -> [Input]
forall rep. SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_p
inp_c_arr :: [Input]
inp_c_arr = SOAC SOACS -> [Input]
forall rep. SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_c
lam_p :: Lambda SOACS
lam_p = SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_p
lam_c :: Lambda SOACS
lam_c = SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_c
w :: SubExp
w = SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p
returned_outvars :: [VName]
returned_outvars = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
success :: [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_outnms SOAC SOACS
res_soac = do
Lambda SOACS
uniq_lam <- Lambda SOACS -> TryFusion (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda SOACS -> TryFusion (Lambda SOACS))
-> Lambda SOACS -> TryFusion (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
res_soac
FusedSOAC -> TryFusion FusedSOAC
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (FusedSOAC -> TryFusion FusedSOAC)
-> FusedSOAC -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
FusedSOAC
ker
{ fsSOAC :: SOAC SOACS
fsSOAC = Lambda SOACS
uniq_lam Lambda SOACS -> SOAC SOACS -> SOAC SOACS
forall rep. Lambda rep -> SOAC rep -> SOAC rep
`SOAC.setLambda` SOAC SOACS
res_soac,
fsOutNames :: [VName]
fsOutNames = [VName]
res_outnms
}
Bool -> TryFusion ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> TryFusion ()) -> Bool -> TryFusion ()
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c
let ker_inputs :: [VName]
ker_inputs = (Input -> VName) -> [Input] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Input -> VName
SOAC.inputArray (FusedSOAC -> [Input]
inputs FusedSOAC
ker)
okInput :: VName -> Input -> Bool
okInput VName
v Input
inp = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= Input -> VName
SOAC.inputArray Input
inp Bool -> Bool -> Bool
|| Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Input -> Maybe VName
SOAC.isVarishInput Input
inp)
inputOrUnfus :: VName -> Bool
inputOrUnfus VName
v = (Input -> Bool) -> [Input] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Input -> Bool
okInput VName
v) (FusedSOAC -> [Input]
inputs FusedSOAC
ker) Bool -> Bool -> Bool
|| VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
ker_inputs
Bool -> TryFusion ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> TryFusion ()) -> Bool -> TryFusion ()
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all VName -> Bool
inputOrUnfus [VName]
outVars
[(VName, Ident)]
outPairs <- [(VName, TypeBase Shape NoUniqueness)]
-> ((VName, TypeBase Shape NoUniqueness)
-> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName]
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)])
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> [TypeBase Shape NoUniqueness]
forall rep. SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p) (((VName, TypeBase Shape NoUniqueness) -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)])
-> ((VName, TypeBase Shape NoUniqueness)
-> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall a b. (a -> b) -> a -> b
$ \(VName
outVar, TypeBase Shape NoUniqueness
t) -> do
VName
outVar' <- [Char] -> TryFusion VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName ([Char] -> TryFusion VName) -> [Char] -> TryFusion VName
forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
outVar [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_elem"
(VName, Ident) -> TryFusion (VName, Ident)
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
outVar, VName -> TypeBase Shape NoUniqueness -> Ident
Ident VName
outVar' TypeBase Shape NoUniqueness
t)
let mapLikeFusionCheck :: ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck =
let (Lambda SOACS
res_lam, [Input]
new_inp) = Names
-> Lambda SOACS
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [Input]
-> (Lambda SOACS, [Input])
forall rep.
Buildable rep =>
Names
-> Lambda rep
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [Input]
-> (Lambda rep, [Input])
fuseMaps Names
unfus_set Lambda SOACS
lam_p [Input]
inp_p_arr [(VName, Ident)]
outPairs Lambda SOACS
lam_c [Input]
inp_c_arr
([VName]
extra_nms, [TypeBase Shape NoUniqueness]
extra_rtps) =
[(VName, TypeBase Shape NoUniqueness)]
-> ([VName], [TypeBase Shape NoUniqueness])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, TypeBase Shape NoUniqueness)]
-> ([VName], [TypeBase Shape NoUniqueness]))
-> [(VName, TypeBase Shape NoUniqueness)]
-> ([VName], [TypeBase Shape NoUniqueness])
forall a b. (a -> b) -> a -> b
$
((VName, TypeBase Shape NoUniqueness) -> Bool)
-> [(VName, TypeBase Shape NoUniqueness)]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
unfus_set) (VName -> Bool)
-> ((VName, TypeBase Shape NoUniqueness) -> VName)
-> (VName, TypeBase Shape NoUniqueness)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, TypeBase Shape NoUniqueness) -> VName
forall a b. (a, b) -> a
fst) ([(VName, TypeBase Shape NoUniqueness)]
-> [(VName, TypeBase Shape NoUniqueness)])
-> [(VName, TypeBase Shape NoUniqueness)]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$
[VName]
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)])
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$
(TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray Int
1) ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$
SOAC SOACS -> [TypeBase Shape NoUniqueness]
forall rep. SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p
res_lam' :: Lambda SOACS
res_lam' = Lambda SOACS
res_lam {lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
res_lam [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ [TypeBase Shape NoUniqueness]
extra_rtps}
in ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp)
case (SOAC SOACS
soac_c, SOAC SOACS
soac_p, Mode
mode) of
(SOAC SOACS, SOAC SOACS, Mode)
_ | SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c -> [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC widths must match."
(SOAC SOACS
_, SOAC SOACS
_, Mode
Horizontal)
| Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker) ->
[Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Horizontal fusion is invalid in the presence of output transforms."
(SOAC SOACS
_, SOAC SOACS
_, Mode
Vertical)
| Names
unfus_set Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
/= Names
forall a. Monoid a => a
mempty,
Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker) ->
[Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot perform diagonal fusion in the presence of output transforms."
( SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_c [Reduce SOACS]
reds_c Lambda SOACS
_) [Input]
_,
SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_p [Reduce SOACS]
reds_p Lambda SOACS
_) [Input]
_,
Mode
_
)
| ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Scan SOACS] -> Int
forall rep. [Scan rep] -> Int
Futhark.scanResults [Scan SOACS]
scans_p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce SOACS] -> Int
forall rep. [Reduce rep] -> Int
Futhark.redResults [Reduce SOACS]
reds_p) [VName]
outVars) FusedSOAC
ker -> do
let red_nes_p :: [SubExp]
red_nes_p = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_p
red_nes_c :: [SubExp]
red_nes_c = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_c
scan_nes_p :: [SubExp]
scan_nes_p = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_p
scan_nes_c :: [SubExp]
scan_nes_c = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_c
(Lambda SOACS
res_lam', [Input]
new_inp) =
Names
-> [VName]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda SOACS, [Input])
forall rep.
Buildable rep =>
Names
-> [VName]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda rep, [Input])
fuseRedomap
Names
unfus_set
[VName]
outVars
Lambda SOACS
lam_p
[SubExp]
scan_nes_p
[SubExp]
red_nes_p
[Input]
inp_p_arr
[(VName, Ident)]
outPairs
Lambda SOACS
lam_c
[SubExp]
scan_nes_c
[SubExp]
red_nes_c
[Input]
inp_c_arr
([VName]
soac_p_scanout, [VName]
soac_p_redout, [VName]
_soac_p_mapout) =
Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_p) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_p) [VName]
outVars
([VName]
soac_c_scanout, [VName]
soac_c_redout, [VName]
soac_c_mapout) =
Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_c) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_c) ([VName] -> ([VName], [VName], [VName]))
-> [VName] -> ([VName], [VName], [VName])
forall a b. (a -> b) -> a -> b
$ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker
unfus_arrs :: [VName]
unfus_arrs = [VName]
returned_outvars [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
\\ ([VName]
soac_p_scanout [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout)
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success
( [VName]
soac_p_scanout
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_scanout
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_redout
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_mapout
[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs
)
(SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma
SubExp
w
([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm ([Scan SOACS]
scans_p [Scan SOACS] -> [Scan SOACS] -> [Scan SOACS]
forall a. [a] -> [a] -> [a]
++ [Scan SOACS]
scans_c) ([Reduce SOACS]
reds_p [Reduce SOACS] -> [Reduce SOACS] -> [Reduce SOACS]
forall a. [a] -> [a] -> [a]
++ [Reduce SOACS]
reds_c) Lambda SOACS
res_lam')
[Input]
new_inp
( SOAC.Scatter SubExp
_len Lambda SOACS
_lam [Input]
_ivs [(Shape, Int, VName)]
dests,
SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_,
Mode
_
)
| Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
[VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
SubExp
-> Lambda SOACS -> [Input] -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
res_lam' [Input]
new_inp [(Shape, Int, VName)]
dests
( SOAC.Hist SubExp
_ [HistOp SOACS]
ops Lambda SOACS
_ [Input]
_,
SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_,
Mode
_
)
| Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
[VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC SOACS
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
res_lam' [Input]
new_inp
( SOAC.Hist SubExp
_ [HistOp SOACS]
ops_c Lambda SOACS
_ [Input]
_,
SOAC.Hist SubExp
_ [HistOp SOACS]
ops_p Lambda SOACS
_ [Input]
_,
Mode
Horizontal
) -> do
let p_num_buckets :: Int
p_num_buckets = [HistOp SOACS] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_p
c_num_buckets :: Int
c_num_buckets = [HistOp SOACS] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_c
(Body SOACS
body_p, Body SOACS
body_c) = (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
body' :: Body SOACS
body' =
Body
{ bodyDec :: BodyDec SOACS
bodyDec = Body SOACS -> BodyDec SOACS
forall rep. Body rep -> BodyDec rep
bodyDec Body SOACS
body_p,
bodyStms :: Stms SOACS
bodyStms = Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_c,
bodyResult :: Result
bodyResult =
Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
c_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_c)
Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
p_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_c)
Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
}
lam' :: Lambda SOACS
lam' =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType =
Int -> TypeBase Shape NoUniqueness -> [TypeBase Shape NoUniqueness]
forall a. Int -> a -> [a]
replicate (Int
c_num_buckets Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
p_num_buckets) (PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
[TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c)
[TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
}
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC SOACS
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w ([HistOp SOACS]
ops_c [HistOp SOACS] -> [HistOp SOACS] -> [HistOp SOACS]
forall a. Semigroup a => a -> a -> a
<> [HistOp SOACS]
ops_p) Lambda SOACS
lam' ([Input]
inp_c_arr [Input] -> [Input] -> [Input]
forall a. Semigroup a => a -> a -> a
<> [Input]
inp_p_arr)
( SOAC.Scatter SubExp
_len_c Lambda SOACS
_lam_c [Input]
ivs_c [(Shape, Int, VName)]
as_c,
SOAC.Scatter
SubExp
_len_p
Lambda SOACS
_lam_p
[Input]
ivs_p
[(Shape, Int, VName)]
as_p,
Mode
Horizontal
) -> do
let zipW :: [(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, array)]
as_xs [a]
xs [(Shape, Int, array)]
as_ys [a]
ys = [a]
xs_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
xs_vals [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_vals
where
([a]
xs_indices, [a]
xs_vals) = [(Shape, Int, array)] -> [a] -> ([a], [a])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_xs [a]
xs
([a]
ys_indices, [a]
ys_vals) = [(Shape, Int, array)] -> [a] -> ([a], [a])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_ys [a]
ys
let (Body SOACS
body_p, Body SOACS
body_c) = (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
let body' :: Body SOACS
body' =
Body
{ bodyDec :: BodyDec SOACS
bodyDec = Body SOACS -> BodyDec SOACS
forall rep. Body rep -> BodyDec rep
bodyDec Body SOACS
body_p,
bodyStms :: Stms SOACS
bodyStms = Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_c,
bodyResult :: Result
bodyResult = [(Shape, Int, VName)]
-> Result -> [(Shape, Int, VName)] -> Result -> Result
forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_c) [(Shape, Int, VName)]
as_p (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
}
let lam' :: Lambda SOACS
lam' =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = [(Shape, Int, VName)]
-> [TypeBase Shape NoUniqueness]
-> [(Shape, Int, VName)]
-> [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness]
forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c) [(Shape, Int, VName)]
as_p (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
}
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
SubExp
-> Lambda SOACS -> [Input] -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
lam' ([Input]
ivs_c [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ [Input]
ivs_p) ([(Shape, Int, VName)]
as_c [(Shape, Int, VName)]
-> [(Shape, Int, VName)] -> [(Shape, Int, VName)]
forall a. [a] -> [a] -> [a]
++ [(Shape, Int, VName)]
as_p)
(SOAC.Scatter {}, SOAC SOACS
_, Mode
_) ->
[Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a scatter with anything else than a scatter or a map"
(SOAC SOACS
_, SOAC.Scatter {}, Mode
_) ->
[Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a scatter with anything else than a scatter or a map"
(SOAC.Stream {}, SOAC.Stream {}, Mode
_) -> do
([VName]
res_nms, SOAC SOACS
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC SOACS
-> SOAC SOACS
-> TryFusion ([VName], SOAC SOACS)
fuseStreamHelper (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC SOACS
soac_c SOAC SOACS
soac_p
[VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_nms SOAC SOACS
res_stream
(SOAC.Stream {}, SOAC SOACS
_, Mode
_) -> do
(SOAC SOACS
soac_p', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
Mode
mode
([VName] -> Names
namesFromList ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
SOAC SOACS
soac_p'
FusedSOAC
ker
(SOAC SOACS
_, SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_, Mode
_) | Just ([Scan SOACS], Lambda SOACS)
_ <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm SOACS
form -> do
(SOAC SOACS
soac_p', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
if SOAC SOACS
soac_p' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_p
then
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
Mode
mode
([VName] -> Names
namesFromList ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
SOAC SOACS
soac_p'
FusedSOAC
ker
else [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."
(SOAC SOACS
_, SOAC.Stream {}, Mode
_) -> do
(SOAC SOACS
soac_c', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_c
if SOAC SOACS
soac_c' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_c
then
Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
Mode
mode
([VName] -> Names
namesFromList ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
[VName]
outVars
SOAC SOACS
soac_p
(FusedSOAC -> TryFusion FusedSOAC)
-> FusedSOAC -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$ FusedSOAC
ker {fsSOAC :: SOAC SOACS
fsSOAC = SOAC SOACS
soac_c', fsOutNames :: [VName]
fsOutNames = (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker}
else [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."
(SOAC SOACS, SOAC SOACS, Mode)
_ -> [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse"
-- | Fuse two 'SOAC.Stream's into a single stream.
--
-- In 'fuseSOACwithKer' the first 'SOAC' argument is the consumer and the
-- second is the producer.  The lambdas are combined as follows:
--
-- * Each stream lambda's leading parameter (bound here as its chunk
--   parameter) is removed.
-- * The consumer's chunk parameter is renamed to the producer's.
-- * The remaining lambdas are combined with 'fuseRedomap'.
-- * The producer's chunk parameter is put back in front of the fused
--   lambda.
--
-- The fused stream has the consumer's width and the producer's neutral
-- elements followed by the consumer's.  The returned output names are:
--
-- * the first @length nes1@ names of @outVars@;
-- * then @out_kernms@;
-- * then the remaining names of @outVars@ that are in @unfus_set@.
--
-- Fails if either argument is not a stream or either stream lambda has
-- no parameters.  Previously the missing-parameter case crashed via
-- 'head'/'tail'.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ lam1 nes1 inp1_arr)
    | chunk1 : params1 <- lambdaParams lam1,
      chunk2 : _ <- lambdaParams lam2 = do
        -- Rename the consumer's chunk parameter to the producer's, then
        -- strip the chunk parameters.  The consumer's remaining
        -- parameters are taken from the substituted lambda 'lam20'.
        let lam20 = substituteNames (M.singleton (paramName chunk2) (paramName chunk1)) lam2
            lam1' = lam1 {lambdaParams = params1}
            lam2' = lam20 {lambdaParams = drop 1 $ lambdaParams lam20}
            (res_lam', new_inp) =
              fuseRedomap
                unfus_set
                outVars
                lam1'
                []
                nes1
                inp1_arr
                outPairs
                lam2'
                []
                nes2
                inp2_arr
            -- Re-introduce the (shared) chunk parameter in front.
            res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
            unfus_accs = take (length nes1) outVars
            unfus_arrs =
              filter (`notElem` unfus_accs) $
                filter (`nameIn` unfus_set) outVars
        pure
          ( unfus_accs ++ out_kernms ++ unfus_arrs,
            SOAC.Stream w2 res_lam'' (nes1 ++ nes2) new_inp
          )
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"
-- | Run 'optimizeSOAC' on a fused kernel's SOAC and output transforms.
-- The kernel is returned with both fields replaced by the optimised
-- versions.  Fails, via 'optimizeSOAC', when no optimisation applies.
optimizeKernel :: Maybe [VName] -> FusedSOAC -> TryFusion FusedSOAC
optimizeKernel inp ker =
  withResult <$> optimizeSOAC inp (fsSOAC ker) (fsOutputTransform ker)
  where
    withResult (soac', trans') =
      ker {fsSOAC = soac', fsOutputTransform = trans'}
-- | Try every optimisation in 'optimizations' on a SOAC, in list order.
--
-- Each optimisation receives the SOAC and output transforms produced by
-- the last successful one.  An optimisation that fails is skipped and the
-- previous state is kept.  Fails with @"No optimisation applied"@ if none
-- succeeded; otherwise returns the final SOAC and transforms.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> pure (soac', os')
  where
    -- The Bool records whether any optimisation has succeeded so far.
    comb (changed, soac', os') f =
      applied <|> pure (changed, soac', os')
      where
        -- Pass the accumulated transforms os' (not the initial os).
        -- Otherwise a later optimisation would see soac' paired with
        -- stale transforms and discard those produced earlier in the fold.
        applied = do
          (soac'', os'') <- f inp soac' os'
          pure (True, soac'', os'')
-- | A rewrite of a SOAC together with its output array transforms.
-- It may fail in 'TryFusion' when it does not apply.  The 'Maybe [VName]'
-- argument is passed through unchanged from 'optimizeSOAC'.
type Optimization =
  Maybe [VName] -> SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
-- | The optimisations that 'optimizeSOAC' attempts, in order.
optimizations :: [Optimization]
optimizations =
  [ iswim
  ]
iswim ::
Maybe [VName] ->
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
iswim :: Maybe [VName]
-> SOAC SOACS
-> ArrayTransforms
-> TryFusion (SOAC SOACS, ArrayTransforms)
iswim Maybe [VName]
_ (SOAC.Screma SubExp
w ScremaForm SOACS
form [Input]
arrs) ArrayTransforms
ots
| Just [Futhark.Scan Lambda SOACS
scan_fun [SubExp]
nes] <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall rep. ScremaForm rep -> Maybe [Scan rep]
Futhark.isScanSOAC ScremaForm SOACS
form,
Just (Pat (TypeBase Shape NoUniqueness)
map_pat, Certs
map_cs, SubExp
map_w, Lambda SOACS
map_fun) <- Lambda SOACS
-> Maybe
(Pat (TypeBase Shape NoUniqueness), Certs, SubExp, Lambda SOACS)
rwimPossible Lambda SOACS
scan_fun,
Just [VName]
nes_names <- (SubExp -> Maybe VName) -> [SubExp] -> Maybe [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM SubExp -> Maybe VName
subExpVar [SubExp]
nes = do
let nes_idents :: [Ident]
nes_idents = (VName -> TypeBase Shape NoUniqueness -> Ident)
-> [VName] -> [TypeBase Shape NoUniqueness] -> [Ident]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> TypeBase Shape NoUniqueness -> Ident
Ident [VName]
nes_names ([TypeBase Shape NoUniqueness] -> [Ident])
-> [TypeBase Shape NoUniqueness] -> [Ident]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
scan_fun
map_nes :: [Input]
map_nes = (Ident -> Input) -> [Ident] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> Input
SOAC.identInput [Ident]
nes_idents
map_arrs' :: [Input]
map_arrs' = [Input]
map_nes [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Input -> Input
SOAC.transposeInput Int
0 Int
1) [Input]
arrs
([Param (TypeBase Shape NoUniqueness)]
scan_acc_params, [Param (TypeBase Shape NoUniqueness)]
scan_elem_params) =
Int
-> [Param (TypeBase Shape NoUniqueness)]
-> ([Param (TypeBase Shape NoUniqueness)],
[Param (TypeBase Shape NoUniqueness)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Input] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
arrs) ([Param (TypeBase Shape NoUniqueness)]
-> ([Param (TypeBase Shape NoUniqueness)],
[Param (TypeBase Shape NoUniqueness)]))
-> [Param (TypeBase Shape NoUniqueness)]
-> ([Param (TypeBase Shape NoUniqueness)],
[Param (TypeBase Shape NoUniqueness)])
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
scan_fun
map_params :: [Param (TypeBase Shape NoUniqueness)]
map_params =
(Param (TypeBase Shape NoUniqueness)
-> Param (TypeBase Shape NoUniqueness))
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness)
-> Param (TypeBase Shape NoUniqueness)
LParam SOACS -> LParam SOACS
removeParamOuterDim [Param (TypeBase Shape NoUniqueness)]
scan_acc_params
[Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ (Param (TypeBase Shape NoUniqueness)
-> Param (TypeBase Shape NoUniqueness))
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo SubExp
w) [Param (TypeBase Shape NoUniqueness)]
scan_elem_params
map_rettype :: [TypeBase Shape NoUniqueness]
map_rettype = (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase Shape NoUniqueness
-> SubExp -> TypeBase Shape NoUniqueness
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w) ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
scan_fun
scan_params :: [LParam SOACS]
scan_params = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_fun
scan_body :: Body SOACS
scan_body = Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
map_fun
scan_rettype :: [TypeBase Shape NoUniqueness]
scan_rettype = Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
map_fun
scan_fun' :: Lambda SOACS
scan_fun' = [LParam SOACS]
-> Body SOACS -> [TypeBase Shape NoUniqueness] -> Lambda SOACS
forall rep.
[LParam rep]
-> Body rep -> [TypeBase Shape NoUniqueness] -> Lambda rep
Lambda [LParam SOACS]
scan_params Body SOACS
scan_body [TypeBase Shape NoUniqueness]
scan_rettype
nes' :: [SubExp]
nes' = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([Input] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
map_nes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
map_params
arrs' :: [VName]
arrs' = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Input] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
map_nes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase Shape NoUniqueness) -> VName)
-> [Param (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape NoUniqueness)]
map_params
ScremaForm SOACS
scan_form <- [Scan SOACS] -> TryFusion (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Scan rep] -> m (ScremaForm rep)
scanSOAC [Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Futhark.Scan Lambda SOACS
scan_fun' [SubExp]
nes']
let map_body :: Body SOACS
map_body =
Stms SOACS -> Result -> Body SOACS
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody
( Stm SOACS -> Stms SOACS
forall rep. Stm rep -> Stms rep
oneStm (Stm SOACS -> Stms SOACS) -> Stm SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$
Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (SubExp
-> Pat (TypeBase Shape NoUniqueness)
-> Pat (TypeBase Shape NoUniqueness)
setPatOuterDimTo SubExp
w Pat (TypeBase Shape NoUniqueness)
map_pat) (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
w [VName]
arrs' ScremaForm SOACS
scan_form
)
(Result -> Body SOACS) -> Result -> Body SOACS
forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes
([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ Pat (TypeBase Shape NoUniqueness) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (TypeBase Shape NoUniqueness)
map_pat
map_fun' :: Lambda SOACS
map_fun' = [LParam SOACS]
-> Body SOACS -> [TypeBase Shape NoUniqueness] -> Lambda SOACS
forall rep.
[LParam rep]
-> Body rep -> [TypeBase Shape NoUniqueness] -> Lambda rep
Lambda [Param (TypeBase Shape NoUniqueness)]
[LParam SOACS]
map_params Body SOACS
map_body [TypeBase Shape NoUniqueness]
map_rettype
perm :: [Int]
perm = case Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
scan_fun of
[] -> []
TypeBase Shape NoUniqueness
t : [TypeBase Shape NoUniqueness]
_ -> Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int
2 .. TypeBase Shape NoUniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase Shape NoUniqueness
t]
(SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma SubExp
map_w ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [] [] Lambda SOACS
map_fun') [Input]
map_arrs',
ArrayTransforms
ots ArrayTransforms -> ArrayTransform -> ArrayTransforms
SOAC.|> Certs -> [Int] -> ArrayTransform
SOAC.Rearrange Certs
map_cs [Int]
perm
)
iswim Maybe [VName]
_ SOAC SOACS
_ ArrayTransforms
_ =
[Char] -> TryFusion (SOAC SOACS, ArrayTransforms)
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"ISWIM does not apply."
-- | Replace a lambda parameter's type with the 'rowType' of its current
-- type, i.e. drop the outermost array dimension from the parameter.
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim param =
  let t = rowType $ paramType param
   in param {paramDec = t}
-- | Set the outermost dimension of a lambda parameter's type to the
-- given size @w@ (via 'setOuterSize'), leaving the rest of the
-- parameter unchanged.
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w param =
  let t = paramType param `setOuterSize` w
   in param {paramDec = t}
-- | Set the outermost dimension of every element type in a pattern to
-- the given size @w@.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w = fmap (`setOuterSize` w)
-- | Extract the array transformations shared by all inputs whose array
-- is among the @interesting@ names.
--
-- Each input is flagged with whether its 'SOAC.inputArray' occurs in
-- @interesting@, and the work is delegated to 'commonTransforms''.
-- Inputs that are not interesting are passed through untouched.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting inps = commonTransforms' inps'
  where
    inps' =
      [ (SOAC.inputArray inp `elem` interesting, inp)
      | inp <- inps
      ]
-- | Repeatedly peel off an array transformation that is shared by every
-- input flagged 'True'.
--
-- In each round, 'inputToOutput' is used to split one transformation
-- off every flagged input. The round succeeds only if:
--
-- * every flagged input yields a transformation, and
-- * all of those transformations are equal.
--
-- On success the transformation is prepended ('SOAC.<|') to the result,
-- and the process recurses on the stripped inputs. Otherwise no further
-- transformations are extracted and the inputs are returned as they are
-- (flags discarded). Unflagged inputs are never modified.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps =
  case foldM inspect (Nothing, []) inps of
    Just (Just mot, inps') ->
      -- The fold prepends, so 'inps'' is in reverse order; restore it.
      first (mot SOAC.<|) $ commonTransforms' $ reverse inps'
    _ -> (SOAC.noTransforms, map snd inps)
  where
    -- Flagged input: it must contribute a transformation equal to the
    -- one seen so far (if any); otherwise the whole round fails.
    inspect (mot, prev) (True, inp) =
      case (mot, inputToOutput inp) of
        (Nothing, Just (ot, inp')) -> Just (Just ot, (True, inp') : prev)
        (Just ot1, Just (ot2, inp'))
          | ot1 == ot2 -> Just (Just ot2, (True, inp') : prev)
        _ -> Nothing
    -- Unflagged input: carried along unchanged.
    inspect (mot, prev) inp = Just (mot, inp : prev)
-- | The depth of a map nest, used by 'pullRearrange' and 'pushRearrange'
-- as the upper bound on 'rearrangeReach' of a permutation.
--
-- It is one more than the smaller of two values:
--
-- * the number of nesting levels, and
-- * the minimum rank among the return types of the outermost nesting
--   (or of the innermost lambda when there are no nestings).
--
-- An empty list of return types counts as rank 0.
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  min resDims (length levels) + 1
  where
    resDims = minDim $ case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    -- Minimum rank of a list of types; 0 for the empty list.  A right
    -- fold over 'min' avoids the lazy-accumulator 'foldl'.
    minDim [] = 0
    minDim (t : ts) = foldr (min . arrayRank) (arrayRank t) ts
-- | Try to move a leading 'SOAC.Rearrange' from the output transforms
-- into the inputs of a map nest.
--
-- This fails (via 'fail') in three cases:
--
-- * the SOAC is not a map nest,
-- * the first output transform is not a rearrange, or
-- * the permutation reaches deeper than 'mapDepth' of the nest.
--
-- On success the rearrange is applied to every input. It is extended
-- with the identity on any dimensions beyond the permutation's length.
-- The nest's return types are permuted accordingly, and the remaining
-- output transforms are returned.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- pure $ SOAC.viewf ots
  if rearrangeReach perm <= mapDepth nest
    then do
      let -- Fit perm to the input's rank: truncate it, or pad it with
          -- the identity on the trailing dimensions.
          perm' inp =
            let r = SOAC.inputRank inp
             in take r perm ++ [length perm .. r - 1]
          addPerm inp = SOAC.addTransform (SOAC.Rearrange cs $ perm' inp) inp
          inputs' = map addPerm $ MapNest.inputs nest
      soac' <-
        MapNest.toSOAC $
          inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
      pure (soac', ots')
    else fail "Cannot pull transpose"
-- | Try to push a rearrange out of the inputs of a map nest and into
-- its output transforms.
--
-- 'fixupInputs' picks the permutation from the inputs named in
-- @inpIds@. If its 'rearrangeReach' is within 'mapDepth' of the nest,
-- the nest is rebuilt from the fixed-up inputs with permuted return
-- types. The inverse permutation is then prepended to @ots@. Otherwise
-- this fails via 'fail'.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  if rearrangeReach perm <= mapDepth nest
    then do
      -- Undo the permutation on the output side so the overall result
      -- is unchanged.
      let invertRearrange = SOAC.Rearrange mempty $ rearrangeInverse perm
      soac' <-
        MapNest.toSOAC $
          inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
      pure (soac', invertRearrange SOAC.<| ots)
    else fail "Cannot push transpose"
-- | Permute the return types recorded in each nesting of a map nest.
--
-- The overall type of the nest ('MapNest.typeOf') is rearranged by
-- @perm@. Nesting /i/ (counting from 0) then receives those types with
-- /i+1/ outer dimensions removed via 'rowType'. The width, innermost
-- lambda and inputs are kept as they are.
--
-- @perm@ is truncated to each type's rank. NOTE(review): this assumes
-- any dropped tail of @perm@ is an identity permutation, which callers
-- must ensure — TODO confirm at call sites.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest
    w
    body
    ( zipWith setReturnType nestings $
        drop 1 $
          iterate (map rowType) ts
    )
    inps
  where
    origts = MapNest.typeOf nest
    rearrangeType' t = rearrangeType (take (arrayRank t) perm) t
    ts = map rearrangeType' origts

    setReturnType nesting t' =
      nesting {MapNest.nestingReturnType = t'}
-- | Pick a permutation to push out of a list of SOAC inputs and add it
-- to every input.
--
-- The permutation comes from the first input that satisfies both:
--
-- * its array is among @inpIds@, and
-- * it has a 'SOAC.Rearrange' as the last element of its transforms
--   (per 'SOAC.viewl').
--
-- Every input then gets an extra rearrange by @perm@, truncated to the
-- input's rank. The result is 'Nothing' if no such permutation exists,
-- or if any input has a rank smaller than 'rearrangeReach' of the
-- permutation.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps =
  case mapMaybe inputRearrange $ filter exposable inps of
    perm : _ -> do
      inps' <- mapM (fixupInput (rearrangeReach perm) perm) inps
      pure (perm, inps')
    _ -> Nothing
  where
    -- Only inputs whose array is named in inpIds may supply the perm.
    exposable = (`elem` inpIds) . SOAC.inputArray

    inputRearrange (SOAC.Input ts _ _)
      | _ SOAC.:> SOAC.Rearrange _ perm <- SOAC.viewl ts = Just perm
    inputRearrange _ = Nothing

    -- Reject inputs too shallow for the permutation's reach.
    fixupInput d perm inp
      | r <- SOAC.inputRank inp,
        r >= d =
          Just $ SOAC.addTransform (SOAC.Rearrange mempty $ take r perm) inp
      | otherwise = Nothing
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape :: SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullReshape (SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
inps) ArrayTransforms
ots
| Just Lambda SOACS
maplam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
Futhark.isMapSOAC ScremaForm SOACS
form,
SOAC.Reshape Certs
cs ReshapeKind
k Shape
shape SOAC.:< ArrayTransforms
ots' <- ArrayTransforms -> ViewF
SOAC.viewf ArrayTransforms
ots,
(TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase Shape NoUniqueness] -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
maplam = do
let mapw' :: SubExp
mapw' = case [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape of
[] -> IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
SubExp
d : [SubExp]
_ -> SubExp
d
trInput :: Input -> Input
trInput Input
inp
| TypeBase Shape NoUniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Input -> TypeBase Shape NoUniqueness
SOAC.inputType Input
inp) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 =
ArrayTransform -> Input -> Input
SOAC.addTransform (Certs -> ReshapeKind -> Shape -> ArrayTransform
SOAC.Reshape Certs
cs ReshapeKind
k Shape
shape) Input
inp
| Bool
otherwise =
ArrayTransform -> Input -> Input
SOAC.addTransform (Certs -> ReshapeKind -> Shape -> ArrayTransform
SOAC.ReshapeOuter Certs
cs ReshapeKind
k Shape
shape) Input
inp
inputs' :: [Input]
inputs' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Input -> Input
trInput [Input]
inps
inputTypes :: [TypeBase Shape NoUniqueness]
inputTypes = (Input -> TypeBase Shape NoUniqueness)
-> [Input] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map Input -> TypeBase Shape NoUniqueness
SOAC.inputType [Input]
inputs'
let outersoac ::
([SOAC.Input] -> SOAC) ->
(SubExp, [SubExp]) ->
TryFusion ([SOAC.Input] -> SOAC)
outersoac :: ([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS)
outersoac [Input] -> SOAC SOACS
inner (SubExp
w, [SubExp]
outershape) = do
let addDims :: TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
addDims TypeBase Shape NoUniqueness
t = TypeBase Shape NoUniqueness
-> Shape -> NoUniqueness -> TypeBase Shape NoUniqueness
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase Shape NoUniqueness
t ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
outershape) NoUniqueness
NoUniqueness
retTypes :: [TypeBase Shape NoUniqueness]
retTypes = (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
addDims ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
maplam
[Param (TypeBase Shape NoUniqueness)]
ps <- [TypeBase Shape NoUniqueness]
-> (TypeBase Shape NoUniqueness
-> TryFusion (Param (TypeBase Shape NoUniqueness)))
-> TryFusion [Param (TypeBase Shape NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [TypeBase Shape NoUniqueness]
inputTypes ((TypeBase Shape NoUniqueness
-> TryFusion (Param (TypeBase Shape NoUniqueness)))
-> TryFusion [Param (TypeBase Shape NoUniqueness)])
-> (TypeBase Shape NoUniqueness
-> TryFusion (Param (TypeBase Shape NoUniqueness)))
-> TryFusion [Param (TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ \TypeBase Shape NoUniqueness
inpt ->
[Char]
-> TypeBase Shape NoUniqueness
-> TryFusion (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"pullReshape_param" (TypeBase Shape NoUniqueness
-> TryFusion (Param (TypeBase Shape NoUniqueness)))
-> TypeBase Shape NoUniqueness
-> TryFusion (Param (TypeBase Shape NoUniqueness))
forall a b. (a -> b) -> a -> b
$
Int -> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray (Shape -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape
shape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
outershape) TypeBase Shape NoUniqueness
inpt
Body SOACS
inner_body <-
Builder SOACS (Body SOACS) -> TryFusion (Body SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder SOACS (Body SOACS) -> TryFusion (Body SOACS))
-> Builder SOACS (Body SOACS) -> TryFusion (Body SOACS)
forall a b. (a -> b) -> a -> b
$
[BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))]
-> BuilderT
SOACS
(State VNameSource)
(Body (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m)) =>
SOAC (Rep m) -> m (Exp (Rep m))
SOAC.toExp (SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource)))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ [Input] -> SOAC SOACS
inner ([Input] -> SOAC SOACS) -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase Shape NoUniqueness) -> Input)
-> [Param (TypeBase Shape NoUniqueness)] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> Input
SOAC.identInput (Ident -> Input)
-> (Param (TypeBase Shape NoUniqueness) -> Ident)
-> Param (TypeBase Shape NoUniqueness)
-> Input
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase Shape NoUniqueness) -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent) [Param (TypeBase Shape NoUniqueness)]
ps]
let inner_fun :: Lambda SOACS
inner_fun =
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = [Param (TypeBase Shape NoUniqueness)]
[LParam SOACS]
ps,
lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = [TypeBase Shape NoUniqueness]
retTypes,
lambdaBody :: Body SOACS
lambdaBody = Body SOACS
inner_body
}
([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS)
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS))
-> ([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma SubExp
w (ScremaForm SOACS -> [Input] -> SOAC SOACS)
-> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda SOACS
inner_fun
[Input] -> SOAC SOACS
op' <-
(([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS))
-> ([Input] -> SOAC SOACS)
-> [(SubExp, [SubExp])]
-> TryFusion ([Input] -> SOAC SOACS)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS)
outersoac (SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma SubExp
mapw' (ScremaForm SOACS -> [Input] -> SOAC SOACS)
-> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda SOACS
maplam) ([(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC SOACS))
-> [(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC SOACS)
forall a b. (a -> b) -> a -> b
$
[SubExp] -> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
1 ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape) ([[SubExp]] -> [(SubExp, [SubExp])])
-> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. (a -> b) -> a -> b
$
Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]])
-> ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [[SubExp]] -> [[SubExp]]
forall a. [a] -> [a]
reverse ([[SubExp]] -> [[SubExp]])
-> ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]])
-> ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [SubExp] -> [[SubExp]]
forall a. [a] -> [[a]]
tails ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$
Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape
(SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([Input] -> SOAC SOACS
op' [Input]
inputs', ArrayTransforms
ots')
pullReshape SOAC SOACS
_ ArrayTransforms
_ = [Char] -> TryFusion (SOAC SOACS, ArrayTransforms)
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot pull reshape"
-- | Try to rewrite a fused SOAC so that the inputs named in @inpIds@ are
-- exposed, returning the rewritten SOAC together with the array transforms
-- that were factored out of those inputs.
--
-- Three strategies are tried in order; the first one that succeeds wins:
--
--   1. push the kernel's output transform with 'pushRearrange', then expose;
--   2. pull the output transform into the SOAC with 'pullRearrange' (this
--      must consume the /entire/ output transform), then expose;
--   3. expose the inputs of the kernel as it is.
--
-- Fails (via 'fail') with @"Cannot expose"@ if none of them works.
exposeInputs ::
  [VName] ->
  FusedSOAC ->
  TryFusion (FusedSOAC, SOAC.ArrayTransforms)
exposeInputs inpIds ker =
  (exposeInputs' =<< pushRearrange')
    <|> (exposeInputs' =<< pullRearrange')
    <|> exposeInputs' ker
  where
    -- The output transform currently attached to the kernel.
    ot :: SOAC.ArrayTransforms
    ot = fsOutputTransform ker

    -- Strategy 1: let 'pushRearrange' rewrite the SOAC and its output
    -- transform, keeping whatever output transform it hands back.
    pushRearrange' :: TryFusion FusedSOAC
    pushRearrange' = do
      (soac', ot') <- pushRearrange inpIds (fsSOAC ker) ot
      pure
        ker
          { fsSOAC = soac',
            fsOutputTransform = ot'
          }

    -- Strategy 2: pull the output transform into the SOAC.  Only accepted
    -- when nothing is left over, because the result has no output transform.
    pullRearrange' :: TryFusion FusedSOAC
    pullRearrange' = do
      (soac', ot') <- pullRearrange (fsSOAC ker) ot
      unless (SOAC.nullTransforms ot') $
        fail "pullRearrange was not enough"
      pure
        ker
          { fsSOAC = soac',
            fsOutputTransform = SOAC.noTransforms
          }

    -- Factor the transforms common to the inputs named in 'inpIds' out of
    -- the kernel's inputs.  This succeeds only if every resulting input is
    -- 'exposed'.
    exposeInputs' :: FusedSOAC -> TryFusion (FusedSOAC, SOAC.ArrayTransforms)
    exposeInputs' ker' =
      case commonTransforms inpIds $ inputs ker' of
        (ot', inps')
          | all exposed inps' ->
              pure (ker' {fsSOAC = inps' `SOAC.setInputs` fsSOAC ker'}, ot')
        _ -> fail "Cannot expose"

    -- An input counts as exposed if it carries no transforms, or if it is
    -- not one of the arrays we were asked to expose.
    exposed :: SOAC.Input -> Bool
    exposed (SOAC.Input ts _ _)
      | SOAC.nullTransforms ts = True
    exposed inp = SOAC.inputArray inp `notElem` inpIds
-- | The functions 'pullOutputTransforms' tries, in this order, when moving
-- output transforms into a SOAC.  Each takes a SOAC and its pending output
-- transforms, and returns a rewritten SOAC together with whatever transforms
-- it could not absorb.
outputTransformPullers ::
  [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers = [pullRearrange, pullReshape]
-- | Pull as many output transforms as possible into the SOAC, using the
-- pullers in 'outputTransformPullers'.
--
-- Pullers are tried in order, and the first one that succeeds is used:
--
--   * If it absorbs every transform, the result has 'SOAC.noTransforms'.
--   * Otherwise the remaining transforms are pulled recursively.  If that
--     recursive step fails, the partial result (SOAC plus leftover
--     transforms) is returned instead.
--
-- If a puller fails, the next one is tried on the original input.  If every
-- puller fails, this fails with @"Cannot pull anything"@.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms = attempt outputTransformPullers
  where
    attempt ::
      [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)] ->
      SOAC ->
      SOAC.ArrayTransforms ->
      TryFusion (SOAC, SOAC.ArrayTransforms)
    attempt [] _ _ = fail "Cannot pull anything"
    attempt (p : ps) soac ots =
      -- Fall back to the remaining pullers only if this one fails outright.
      pullWith p soac ots <|> attempt ps soac ots

    -- Apply a single puller.  Once it succeeds, this never fails: leftover
    -- transforms are either pulled recursively or returned as they are.
    pullWith p soac ots = do
      (soac', ots') <- p soac ots
      if SOAC.nullTransforms ots'
        then pure (soac', SOAC.noTransforms)
        else pullOutputTransforms soac' ots' <|> pure (soac', ots')