{-# LANGUAGE TypeFamilies #-}

-- | Facilities for fusing two SOACs.
--
-- When the fusion algorithm decides that it's worth fusing two SOAC
-- statements, this is the module that tries to see if that's
-- possible.  May involve massaging either producer or consumer in
-- various ways.
module Futhark.Optimise.Fusion.TryFusion
  ( FusedSOAC (..),
    Mode (..),
    attemptFusion,
  )
where

import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.MapNest qualified as MapNest
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)

-- | The monad in which a single fusion attempt runs.  It reads the
-- current 'Scope', threads a 'VNameSource' for fresh names, and can
-- fail (via 'empty' or 'fail').  Failure sits beneath the 'StateT', so
-- when two attempts are combined with '<|>' the second one starts from
-- the name source as it was before the first (failed) attempt.
newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Run a fusion attempt in the given scope.
--
-- On success the result is returned in 'Just' and the caller's name
-- source is advanced past any names the attempt generated.  On failure
-- 'Nothing' is returned and the caller's name source is left exactly as
-- it was.
tryFusion ::
  (MonadFreshNames m) =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
  case runStateT (runReaderT m types) src of
    Just (x, src') -> (Just x, src')
    -- A failed attempt must not consume names: hand back the original source.
    Nothing -> (Nothing, src)

-- | Lift a 'Maybe' into 'TryFusion': 'Just' becomes a successful
-- result, and 'Nothing' makes the current fusion attempt fail.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") pure

-- | SOACs (in the 'SOAC.SOAC' view) over the 'SOACS' representation.
type SOAC = SOAC.SOAC SOACS

-- | Map nests over the 'SOACS' representation.
type MapNest = MapNest.MapNest SOACS

-- | Split the leading array transformation (as exposed by 'SOAC.viewf')
-- off an input.  Returns that transformation together with the same
-- input minus it, or 'Nothing' if the input carries no transformations.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing

-- | A fused SOAC contains a bit of extra information.
data FusedSOAC = FusedSOAC
  { -- | The actual SOAC.
    fsSOAC :: SOAC,
    -- | A transformation to be applied to *all* results of the SOAC.
    fsOutputTransform :: SOAC.ArrayTransforms,
    -- | The outputs of the SOAC (i.e. the names in the pattern that
    -- the result of this SOAC should be bound to).
    fsOutNames :: [VName]
  }
  deriving (Show)

-- | The inputs of the SOAC wrapped by a 'FusedSOAC'.
inputs :: FusedSOAC -> [SOAC.Input]
inputs = SOAC.inputs . fsSOAC

-- | Replace the inputs of the SOAC wrapped by a 'FusedSOAC'.  The output
-- transform and output names are left untouched.
setInputs :: [SOAC.Input] -> FusedSOAC -> FusedSOAC
setInputs inps ker = ker {fsSOAC = inps `SOAC.setInputs` fsSOAC ker}

-- | Try to simplify the producer with 'optimizeSOAC' and then retry
-- fusion.
--
-- 'optimizeSOAC' may move array transformations out of the producer's
-- results.  Those transformations are re-applied, as initial
-- transformations, to every consumer input that reads one of the
-- producer's outputs.  The consumer's input types are then refreshed to
-- match the rewritten producer's result types.  Fails if 'optimizeSOAC'
-- fails.
tryOptimizeSOAC ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeSOAC mode unfus_nms outVars soac ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = map (addInitialTransformIfRelevant ots) (inputs ker) `setInputs` ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules mode unfus_nms outVars soac' ker''
  where
    -- Only inputs that read a producer output are affected by the
    -- transformations pulled out of the producer.
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
          SOAC.addInitialTransforms ots inp
      | otherwise =
          inp

-- | Try to simplify the consumer with 'optimizeKernel' (given the
-- producer's output names) and then retry fusion with the unchanged
-- producer.  Fails if 'optimizeKernel' fails.
tryOptimizeKernel ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeKernel mode unfus_nms outVars soac ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules mode unfus_nms outVars soac ker'

-- | Rewrite the consumer with 'exposeInputs' and fuse.
--
-- If 'exposeInputs' leaves no transformations behind, the producer is
-- fused directly into the rewritten consumer.  Otherwise the remaining
-- transformations are pushed through the producer's outputs with
-- 'pullOutputTransforms'.  This is only attempted when no producer
-- outputs must be preserved (i.e. not for diagonal fusion), and only
-- succeeds if every transformation could be pulled.  Fusion is then
-- retried with the rewritten producer.
tryExposeInputs ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryExposeInputs mode unfus_nms outVars soac ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer mode unfus_nms outVars soac ker'
    else do
      guard $ unfus_nms == mempty
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules mode unfus_nms outVars soac' ker''
        else fail "tryExposeInputs could not pull SOAC transforms"

-- | For every input of the wrapped SOAC whose array is named by one of
-- the given identifiers, replace the input's recorded type with that
-- identifier's type.  All other inputs (and the input transformations)
-- are left unchanged.  Callers use this after rewriting the producer,
-- so that the consumer's inputs agree with the producer's new result
-- types.
fixInputTypes :: [Ident] -> FusedSOAC -> FusedSOAC
fixInputTypes outIdents ker =
  ker {fsSOAC = fixInputTypes' $ fsSOAC ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
          SOAC.Input ts v $ identType v'
    fixInputType inp = inp

-- | Try each fusion strategy in turn and return the result of the first
-- one that succeeds.  The order is: optimise the producer, optimise the
-- consumer, fuse directly, then expose consumer inputs.
--
-- 'tryOptimizeSOAC', 'tryOptimizeKernel' and 'tryExposeInputs' call back
-- into this function after rewriting.  The recursion therefore only
-- ends when those rewriting steps eventually fail.  Because 'TryFusion'
-- backtracks on failure, each alternative starts from the same name
-- source.
applyFusionRules ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
applyFusionRules mode unfus_nms outVars soac ker =
  tryOptimizeSOAC mode unfus_nms outVars soac ker
    <|> tryOptimizeKernel mode unfus_nms outVars soac ker
    <|> fuseSOACwithKer mode unfus_nms outVars soac ker
    <|> tryExposeInputs mode unfus_nms outVars soac ker

-- | Which kind of fusion is being attempted: horizontal or vertical.
-- Vertical fusion also covers "diagonal" fusion, in which the final
-- SOAC still produces some of the producer's results.
data Mode
  = -- | Horizontal fusion.
    Horizontal
  | -- | Vertical fusion (including diagonal fusion).
    Vertical

-- | Attempt fusing the producer into the consumer.  Returns 'Nothing'
-- if no fusion rule applies.  In that case the caller's name source is
-- not changed.
attemptFusion ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Mode ->
  -- | Outputs of the producer that should still be output by the
  -- fusion result (corresponding to "diagonal fusion").
  Names ->
  -- | The outputs of the producer SOAC.
  [VName] ->
  -- | The producer.
  SOAC ->
  -- | The consumer.
  FusedSOAC ->
  m (Maybe FusedSOAC)
attemptFusion mode unfus_nms outVars soac ker = do
  scope <- askScope
  tryFusion (applyFusionRules mode unfus_nms outVars soac ker) scope

-- | Check that the consumer does not use any scan or reduce results.
--
-- The first component of the pair holds the producer's scan and reduce
-- outputs.  The second component (its map outputs) is not inspected.
--
-- Every consumer input is examined by its underlying array, so a scan
-- or reduce result is also rejected when it is read through array
-- transformations.  Such uses were previously missed, because only
-- inputs accepted by 'SOAC.isVarishInput' were checked.
scremaFusionOK :: ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (nonmap_outs, _map_outs) ker =
  all ((`notElem` nonmap_outs) . SOAC.inputArray) (inputs ker)

-- | Check that the consumer uses all the outputs of the producer
-- unmodified.  An output counts as used when some consumer input is
-- accepted by 'SOAC.isVarishInput' and names it.
mapWriteFusionOK :: [VName] -> FusedSOAC -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)

-- | The brain of this module: Fusing a SOAC with a Kernel.
fuseSOACwithKer ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
fuseSOACwithKer :: Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer Mode
mode Names
unfus_set [VName]
outVars SOAC SOACS
soac_p FusedSOAC
ker = do
  -- We are fusing soac_p into soac_c, i.e, the output of soac_p is going
  -- into soac_c.
  let soac_c :: SOAC SOACS
soac_c = FusedSOAC -> SOAC SOACS
fsSOAC FusedSOAC
ker
      inp_p_arr :: [Input]
inp_p_arr = SOAC SOACS -> [Input]
forall rep. SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_p
      inp_c_arr :: [Input]
inp_c_arr = SOAC SOACS -> [Input]
forall rep. SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_c
      lam_p :: Lambda SOACS
lam_p = SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_p
      lam_c :: Lambda SOACS
lam_c = SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_c
      w :: SubExp
w = SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p
      returned_outvars :: [VName]
returned_outvars = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
      success :: [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_outnms SOAC SOACS
res_soac = do
        -- Avoid name duplication, because the producer lambda is not
        -- removed from the program until much later.
        Lambda SOACS
uniq_lam <- Lambda SOACS -> TryFusion (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda (Lambda SOACS -> TryFusion (Lambda SOACS))
-> Lambda SOACS -> TryFusion (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> Lambda SOACS
forall rep. SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
res_soac
        FusedSOAC -> TryFusion FusedSOAC
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (FusedSOAC -> TryFusion FusedSOAC)
-> FusedSOAC -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
          FusedSOAC
ker
            { fsSOAC :: SOAC SOACS
fsSOAC = Lambda SOACS
uniq_lam Lambda SOACS -> SOAC SOACS -> SOAC SOACS
forall rep. Lambda rep -> SOAC rep -> SOAC rep
`SOAC.setLambda` SOAC SOACS
res_soac,
              fsOutNames :: [VName]
fsOutNames = [VName]
res_outnms
            }

  -- Can only fuse SOACs with same width.
  Bool -> TryFusion ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> TryFusion ()) -> Bool -> TryFusion ()
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c

  -- If we are getting rid of a producer output, then it must be used
  -- exclusively without any transformations.
  let ker_inputs :: [VName]
ker_inputs = (Input -> VName) -> [Input] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Input -> VName
SOAC.inputArray (FusedSOAC -> [Input]
inputs FusedSOAC
ker)
      okInput :: VName -> Input -> Bool
okInput VName
v Input
inp = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= Input -> VName
SOAC.inputArray Input
inp Bool -> Bool -> Bool
|| Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Input -> Maybe VName
SOAC.isVarishInput Input
inp)
      inputOrUnfus :: VName -> Bool
inputOrUnfus VName
v = (Input -> Bool) -> [Input] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Input -> Bool
okInput VName
v) (FusedSOAC -> [Input]
inputs FusedSOAC
ker) Bool -> Bool -> Bool
|| VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
ker_inputs

  Bool -> TryFusion ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> TryFusion ()) -> Bool -> TryFusion ()
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all VName -> Bool
inputOrUnfus [VName]
outVars

  [(VName, Ident)]
outPairs <- [(VName, TypeBase Shape NoUniqueness)]
-> ((VName, TypeBase Shape NoUniqueness)
    -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName]
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([TypeBase Shape NoUniqueness]
 -> [(VName, TypeBase Shape NoUniqueness)])
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. TypeBase Shape u -> TypeBase Shape u
rowType ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> [TypeBase Shape NoUniqueness]
forall rep. SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p) (((VName, TypeBase Shape NoUniqueness) -> TryFusion (VName, Ident))
 -> TryFusion [(VName, Ident)])
-> ((VName, TypeBase Shape NoUniqueness)
    -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall a b. (a -> b) -> a -> b
$ \(VName
outVar, TypeBase Shape NoUniqueness
t) -> do
    VName
outVar' <- [Char] -> TryFusion VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName ([Char] -> TryFusion VName) -> [Char] -> TryFusion VName
forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
outVar [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_elem"
    (VName, Ident) -> TryFusion (VName, Ident)
forall a. a -> TryFusion a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
outVar, VName -> TypeBase Shape NoUniqueness -> Ident
Ident VName
outVar' TypeBase Shape NoUniqueness
t)

  let mapLikeFusionCheck :: ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck =
        let (Lambda SOACS
res_lam, [Input]
new_inp) = Names
-> Lambda SOACS
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [Input]
-> (Lambda SOACS, [Input])
forall rep.
Buildable rep =>
Names
-> Lambda rep
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [Input]
-> (Lambda rep, [Input])
fuseMaps Names
unfus_set Lambda SOACS
lam_p [Input]
inp_p_arr [(VName, Ident)]
outPairs Lambda SOACS
lam_c [Input]
inp_c_arr
            ([VName]
extra_nms, [TypeBase Shape NoUniqueness]
extra_rtps) =
              [(VName, TypeBase Shape NoUniqueness)]
-> ([VName], [TypeBase Shape NoUniqueness])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, TypeBase Shape NoUniqueness)]
 -> ([VName], [TypeBase Shape NoUniqueness]))
-> [(VName, TypeBase Shape NoUniqueness)]
-> ([VName], [TypeBase Shape NoUniqueness])
forall a b. (a -> b) -> a -> b
$
                ((VName, TypeBase Shape NoUniqueness) -> Bool)
-> [(VName, TypeBase Shape NoUniqueness)]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
unfus_set) (VName -> Bool)
-> ((VName, TypeBase Shape NoUniqueness) -> VName)
-> (VName, TypeBase Shape NoUniqueness)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, TypeBase Shape NoUniqueness) -> VName
forall a b. (a, b) -> a
fst) ([(VName, TypeBase Shape NoUniqueness)]
 -> [(VName, TypeBase Shape NoUniqueness)])
-> [(VName, TypeBase Shape NoUniqueness)]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$
                  [VName]
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([TypeBase Shape NoUniqueness]
 -> [(VName, TypeBase Shape NoUniqueness)])
-> [TypeBase Shape NoUniqueness]
-> [(VName, TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$
                    (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray Int
1) ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$
                      SOAC SOACS -> [TypeBase Shape NoUniqueness]
forall rep. SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p
            res_lam' :: Lambda SOACS
res_lam' = Lambda SOACS
res_lam {lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
res_lam [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ [TypeBase Shape NoUniqueness]
extra_rtps}
         in ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp)

  case (SOAC SOACS
soac_c, SOAC SOACS
soac_p, Mode
mode) of
    (SOAC SOACS, SOAC SOACS, Mode)
_ | SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS -> SubExp
forall rep. SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c -> [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC widths must match."
    (SOAC SOACS
_, SOAC SOACS
_, Mode
Horizontal)
      | Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker) ->
          [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Horizontal fusion is invalid in the presence of output transforms."
    (SOAC SOACS
_, SOAC SOACS
_, Mode
Vertical)
      | Names
unfus_set Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
/= Names
forall a. Monoid a => a
mempty,
        Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker) ->
          [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot perform diagonal fusion in the presence of output transforms."
    ( SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_c [Reduce SOACS]
reds_c Lambda SOACS
_) [Input]
_,
      SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_p [Reduce SOACS]
reds_p Lambda SOACS
_) [Input]
_,
      Mode
_
      )
        | ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Scan SOACS] -> Int
forall rep. [Scan rep] -> Int
Futhark.scanResults [Scan SOACS]
scans_p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce SOACS] -> Int
forall rep. [Reduce rep] -> Int
Futhark.redResults [Reduce SOACS]
reds_p) [VName]
outVars) FusedSOAC
ker -> do
            let red_nes_p :: [SubExp]
red_nes_p = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_p
                red_nes_c :: [SubExp]
red_nes_c = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_c
                scan_nes_p :: [SubExp]
scan_nes_p = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_p
                scan_nes_c :: [SubExp]
scan_nes_c = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_c
                (Lambda SOACS
res_lam', [Input]
new_inp) =
                  Names
-> [VName]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda SOACS, [Input])
forall rep.
Buildable rep =>
Names
-> [VName]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda rep, [Input])
fuseRedomap
                    Names
unfus_set
                    [VName]
outVars
                    Lambda SOACS
lam_p
                    [SubExp]
scan_nes_p
                    [SubExp]
red_nes_p
                    [Input]
inp_p_arr
                    [(VName, Ident)]
outPairs
                    Lambda SOACS
lam_c
                    [SubExp]
scan_nes_c
                    [SubExp]
red_nes_c
                    [Input]
inp_c_arr
                ([VName]
soac_p_scanout, [VName]
soac_p_redout, [VName]
_soac_p_mapout) =
                  Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_p) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_p) [VName]
outVars
                ([VName]
soac_c_scanout, [VName]
soac_c_redout, [VName]
soac_c_mapout) =
                  Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_c) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_c) ([VName] -> ([VName], [VName], [VName]))
-> [VName] -> ([VName], [VName], [VName])
forall a b. (a -> b) -> a -> b
$ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker
                unfus_arrs :: [VName]
unfus_arrs = [VName]
returned_outvars [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
\\ ([VName]
soac_p_scanout [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout)
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success
              ( [VName]
soac_p_scanout
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_scanout
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_redout
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_mapout
                  [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs
              )
              (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma
                SubExp
w
                ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm ([Scan SOACS]
scans_p [Scan SOACS] -> [Scan SOACS] -> [Scan SOACS]
forall a. [a] -> [a] -> [a]
++ [Scan SOACS]
scans_c) ([Reduce SOACS]
reds_p [Reduce SOACS] -> [Reduce SOACS] -> [Reduce SOACS]
forall a. [a] -> [a] -> [a]
++ [Reduce SOACS]
reds_c) Lambda SOACS
res_lam')
                [Input]
new_inp

    ------------------
    -- Scatter fusion --
    ------------------

    -- Map-Scatter fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Scatter is not writing to any array used in
    -- the Map.
    ( SOAC.Scatter SubExp
_len Lambda SOACS
_lam [Input]
_ivs [(Shape, Int, VName)]
dests,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_,
      Mode
_
      )
        | Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the scatter, i.e., not used elsewhere.
          (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
          -- 2. all arrays produced by the map are input to the scatter.
          [VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
            let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
              SubExp
-> Lambda SOACS -> [Input] -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
res_lam' [Input]
new_inp [(Shape, Int, VName)]
dests

    -- Map-Hist fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Hist is not writing to any array used in
    -- the Map.
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops Lambda SOACS
_ [Input]
_,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_,
      Mode
_
      )
        | Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the hist, i.e., not used elsewhere.
          (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
          -- 2. all arrays produced by the map are input to the hist.
          [VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
            let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
              SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC SOACS
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
res_lam' [Input]
new_inp

    -- Hist-Hist fusion
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops_c Lambda SOACS
_ [Input]
_,
      SOAC.Hist SubExp
_ [HistOp SOACS]
ops_p Lambda SOACS
_ [Input]
_,
      Mode
Horizontal
      ) -> do
        let p_num_buckets :: Int
p_num_buckets = [HistOp SOACS] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_p
            c_num_buckets :: Int
c_num_buckets = [HistOp SOACS] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_c
            (Body SOACS
body_p, Body SOACS
body_c) = (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
            body' :: Body SOACS
body' =
              Body
                { bodyDec :: BodyDec SOACS
bodyDec = Body SOACS -> BodyDec SOACS
forall rep. Body rep -> BodyDec rep
bodyDec Body SOACS
body_p, -- body_p and body_c have the same decorations
                  bodyStms :: Stms SOACS
bodyStms = Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_c,
                  bodyResult :: Result
bodyResult =
                    Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
c_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_c)
                      Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
p_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
                      Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_c)
                      Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
                }
            lam' :: Lambda SOACS
lam' =
              Lambda
                { lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
                  lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
                  lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType =
                    Int -> TypeBase Shape NoUniqueness -> [TypeBase Shape NoUniqueness]
forall a. Int -> a -> [a]
replicate (Int
c_num_buckets Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
p_num_buckets) (PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
                      [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c)
                      [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
                }
        [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
          SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC SOACS
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w ([HistOp SOACS]
ops_c [HistOp SOACS] -> [HistOp SOACS] -> [HistOp SOACS]
forall a. Semigroup a => a -> a -> a
<> [HistOp SOACS]
ops_p) Lambda SOACS
lam' ([Input]
inp_c_arr [Input] -> [Input] -> [Input]
forall a. Semigroup a => a -> a -> a
<> [Input]
inp_p_arr)

    -- Scatter-write fusion.
    ( SOAC.Scatter SubExp
_len_c Lambda SOACS
_lam_c [Input]
ivs_c [(Shape, Int, VName)]
as_c,
      SOAC.Scatter
        SubExp
_len_p
        Lambda SOACS
_lam_p
        [Input]
ivs_p
        [(Shape, Int, VName)]
as_p,
      Mode
Horizontal
      ) -> do
        let zipW :: [(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, array)]
as_xs [a]
xs [(Shape, Int, array)]
as_ys [a]
ys = [a]
xs_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
xs_vals [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_vals
              where
                ([a]
xs_indices, [a]
xs_vals) = [(Shape, Int, array)] -> [a] -> ([a], [a])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_xs [a]
xs
                ([a]
ys_indices, [a]
ys_vals) = [(Shape, Int, array)] -> [a] -> ([a], [a])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_ys [a]
ys
        let (Body SOACS
body_p, Body SOACS
body_c) = (Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
        let body' :: Body SOACS
body' =
              Body
                { bodyDec :: BodyDec SOACS
bodyDec = Body SOACS -> BodyDec SOACS
forall rep. Body rep -> BodyDec rep
bodyDec Body SOACS
body_p, -- body_p and body_c have the same decorations
                  bodyStms :: Stms SOACS
bodyStms = Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body_c,
                  bodyResult :: Result
bodyResult = [(Shape, Int, VName)]
-> Result -> [(Shape, Int, VName)] -> Result -> Result
forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_c) [(Shape, Int, VName)]
as_p (Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult Body SOACS
body_p)
                }
        let lam' :: Lambda SOACS
lam' =
              Lambda
                { lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
                  lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
                  lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = [(Shape, Int, VName)]
-> [TypeBase Shape NoUniqueness]
-> [(Shape, Int, VName)]
-> [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness]
forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c) [(Shape, Int, VName)]
as_p (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
                }
        [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC SOACS -> TryFusion FusedSOAC)
-> SOAC SOACS -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$
          SubExp
-> Lambda SOACS -> [Input] -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
lam' ([Input]
ivs_c [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ [Input]
ivs_p) ([(Shape, Int, VName)]
as_c [(Shape, Int, VName)]
-> [(Shape, Int, VName)] -> [(Shape, Int, VName)]
forall a. [a] -> [a] -> [a]
++ [(Shape, Int, VName)]
as_p)
    (SOAC.Scatter {}, SOAC SOACS
_, Mode
_) ->
      [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a scatter with anything else than a scatter or a map"
    (SOAC SOACS
_, SOAC.Scatter {}, Mode
_) ->
      [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a scatter with anything else than a scatter or a map"
    ----------------------------
    -- Stream-Stream Fusions: --
    ----------------------------
    (SOAC.Stream {}, SOAC.Stream {}, Mode
_) -> do
      -- fuse two SEQUENTIAL streams
      ([VName]
res_nms, SOAC SOACS
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC SOACS
-> SOAC SOACS
-> TryFusion ([VName], SOAC SOACS)
fuseStreamHelper (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC SOACS
soac_c SOAC SOACS
soac_p
      [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_nms SOAC SOACS
res_stream
    -------------------------------------------------------------------
    --- If one is a stream, translate the other to a stream as well.---
    --- This does not get in trouble (infinite computation) because ---
    ---   scan's translation to Stream introduces a hindrance to    ---
    ---   (horizontal fusion), hence repeated application is for the---
    ---   moment impossible. However, if with a dependence-graph rep---
    ---   we could run in an infinite recursion, i.e., repeatedly   ---
    ---   fusing map o scan into an infinity of Stream levels!      ---
    -------------------------------------------------------------------
    (SOAC.Stream {}, SOAC SOACS
_, Mode
_) -> do
      -- If this rule is matched then soac_p is NOT a stream.
      -- To fuse a stream kernel, we transform soac_p to a stream, which
      -- borrows the sequential/parallel property of the soac_c Stream,
      -- and recursively perform stream-stream fusion.
      (SOAC SOACS
soac_p', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
      Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
        Mode
mode
        ([VName] -> Names
namesFromList ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
        ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
        SOAC SOACS
soac_p'
        FusedSOAC
ker
    (SOAC SOACS
_, SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_, Mode
_) | Just ([Scan SOACS], Lambda SOACS)
_ <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm SOACS
form -> do
      -- A Scan soac can be currently only fused as a (sequential) stream,
      -- hence it is first translated to a (sequential) Stream and then
      -- fusion with a kernel is attempted.
      (SOAC SOACS
soac_p', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
      if SOAC SOACS
soac_p' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_p
        then
          Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
            Mode
mode
            ([VName] -> Names
namesFromList ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
            ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
            SOAC SOACS
soac_p'
            FusedSOAC
ker
        else [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."
    (SOAC SOACS
_, SOAC.Stream {}, Mode
_) -> do
      -- If it reached this case then soac_c is NOT a Stream kernel,
      -- hence transform the kernel's soac to a stream and attempt
      -- stream-stream fusion recursively.
      -- The newly created stream corresponding to soac_c borrows the
      -- sequential/parallel property of the soac_p stream.
      (SOAC SOACS
soac_c', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_c
      if SOAC SOACS
soac_c' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_c
        then
          Mode
-> Names
-> [VName]
-> SOAC SOACS
-> FusedSOAC
-> TryFusion FusedSOAC
fuseSOACwithKer
            Mode
mode
            ([VName] -> Names
namesFromList ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
            [VName]
outVars
            SOAC SOACS
soac_p
            (FusedSOAC -> TryFusion FusedSOAC)
-> FusedSOAC -> TryFusion FusedSOAC
forall a b. (a -> b) -> a -> b
$ FusedSOAC
ker {fsSOAC :: SOAC SOACS
fsSOAC = SOAC SOACS
soac_c', fsOutNames :: [VName]
fsOutNames = (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker}
        else [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."

    ---------------------------------
    --- DEFAULT, CANNOT FUSE CASE ---
    ---------------------------------
    (SOAC SOACS, SOAC SOACS, Mode)
_ -> [Char] -> TryFusion FusedSOAC
forall a. [Char] -> TryFusion a
forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse"

-- | Fuse a producer stream into a consumer stream.
--
-- The fifth argument is the consumer stream and the sixth the producer
-- stream. The leading parameter of each stream lambda is its chunk
-- parameter. The fusion proceeds as follows:
--
-- * The consumer's chunk parameter is renamed to the producer's.
-- * Both chunk parameters are stripped.
-- * The remaining lambdas are fused with 'fuseRedomap', as a redomap
--   with no scan neutral elements.
-- * The producer's chunk parameter is put back at the front of the
--   fused lambda.
--
-- The fused stream uses the consumer's width, and its neutral elements
-- are the producer's followed by the consumer's. The returned output
-- names are, in order:
--
-- * the first @length nes1@ entries of @outVars@ (the producer
--   accumulators);
-- * the consumer kernel's output names (@out_kernms@);
-- * the producer outputs that are in @unfus_set@ and are not
--   accumulators.
--
-- Fails (rather than crashing) if either SOAC is not a stream, or if a
-- stream lambda has no chunk parameter.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ lam1 nes1 inp1_arr)
    -- Match the chunk parameters instead of using the partial head/tail,
    -- so a malformed stream falls through to the failing equation below.
    | chunk1 : params1 <- lambdaParams lam1,
      chunk2 : _ <- lambdaParams lam2 = do
        -- very similar to redomap o redomap composition, but need
        -- to remove first the `chunk' parameters of streams'
        -- lambdas and put them in the resulting stream lambda.
        let lam20 =
              substituteNames
                (M.singleton (paramName chunk2) (paramName chunk1))
                lam2
            lam1' = lam1 {lambdaParams = params1}
            -- Substitution keeps the parameter count, so this drops
            -- exactly the (renamed) chunk parameter.
            lam2' = lam20 {lambdaParams = drop 1 $ lambdaParams lam20}
            (res_lam', new_inp) =
              fuseRedomap
                unfus_set
                outVars
                lam1'
                []
                nes1
                inp1_arr
                outPairs
                lam2'
                []
                nes2
                inp2_arr
            res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
            -- The producer's accumulator results are always kept.
            unfus_accs = take (length nes1) outVars
            unfus_arrs =
              filter
                (\v -> v `nameIn` unfus_set && v `notElem` unfus_accs)
                outVars
        pure
          ( unfus_accs ++ out_kernms ++ unfus_arrs,
            SOAC.Stream w2 res_lam'' (nes1 ++ nes2) new_inp
          )
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"

-- Here follows optimizations and transforms to expose fusability.

-- | Run 'optimizeSOAC' on the SOAC held by a fused kernel.
--
-- The rewritten SOAC and its output transforms are written back into the
-- kernel; every other field is left untouched. Fails whenever
-- 'optimizeSOAC' fails.
optimizeKernel :: Maybe [VName] -> FusedSOAC -> TryFusion FusedSOAC
optimizeKernel inp ker =
  storeBack <$> optimizeSOAC inp (fsSOAC ker) (fsOutputTransform ker)
  where
    storeBack (newSoac, newTrans) =
      ker {fsSOAC = newSoac, fsOutputTransform = newTrans}

-- | Try every rewrite in 'optimizations' on a SOAC and its output
-- transforms, in list order.
--
-- Each successful rewrite's result (SOAC and transforms together) is fed
-- into the next rewrite. A rewrite that fails is skipped, leaving the
-- accumulator unchanged. The whole function fails unless at least one
-- rewrite succeeded.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> pure (soac', os')
  where
    comb ::
      (Bool, SOAC, SOAC.ArrayTransforms) ->
      Optimization ->
      TryFusion (Bool, SOAC, SOAC.ArrayTransforms)
    comb (changed, soac', os') f =
      ( do
          -- Pass the transforms accumulated so far (os'), not the
          -- original ones, so that the SOAC and its transforms stay
          -- paired and no earlier rewrite's transform changes are lost.
          (soac'', os'') <- f inp soac' os'
          pure (True, soac'', os'')
      )
        -- If the rewrite does not apply, keep the accumulator as is.
        <|> pure (changed, soac', os')

-- | A single rewrite attempt on a SOAC together with its accompanying
-- array transforms.  The leading @Maybe [VName]@ is supplied by the
-- caller ('iswim' ignores it).  A rewrite that does not apply fails in
-- the 'TryFusion' monad, which lets the caller fall back with '<|>'.
type Optimization =
  Maybe [VName] -> SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)

-- | Every rewrite the SOAC optimiser tries.  At present this is just
-- 'iswim'.
optimizations :: [Optimization]
optimizations = [iswim]

-- | ISWIM: turn a scan whose operator is (per 'rwimPossible') a map into
-- a map whose body is a scan.
--
-- Applies only when the SOAC is a 'SOAC.Screma' holding exactly one
-- scan, 'rwimPossible' accepts the scan operator, and every neutral
-- element is a variable.  The new map has these inputs:
--
-- * the neutral elements first;
-- * then the original input arrays with dimensions 0 and 1 transposed.
--
-- A 'SOAC.Rearrange' that swaps the two outermost dimensions (carrying
-- the certificates returned by 'rwimPossible') is appended to the
-- transforms.  Fails with "ISWIM does not apply." in every other case.
-- The @Maybe [VName]@ argument is ignored.
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w form arrs) ots
  | Just [Futhark.Scan op nes] <- Futhark.isScanSOAC form,
    Just (inner_pat, inner_cs, inner_w, inner_lam) <- rwimPossible op,
    Just ne_vars <- mapM subExpVar nes = do
      let op_ts = lambdaReturnType op
          -- Neutral elements are fed to the outer map as ordinary inputs.
          ne_inputs =
            [SOAC.identInput (Ident v t) | (v, t) <- zip ne_vars op_ts]
          (acc_ps, elem_ps) = splitAt (length arrs) (lambdaParams op)
          outer_ps =
            map removeParamOuterDim acc_ps
              ++ map (setParamOuterDimTo w) elem_ps
          (ne_names', arr_names') =
            splitAt (length ne_inputs) (map paramName outer_ps)
          inner_scan_lam =
            Lambda
              (lambdaParams inner_lam)
              (lambdaBody inner_lam)
              (lambdaReturnType inner_lam)

      scan_form <- scanSOAC [Futhark.Scan inner_scan_lam (map Var ne_names')]

      let scan_stm =
            Let (setPatOuterDimTo w inner_pat) (defAux ()) . Op $
              Futhark.Screma w arr_names' scan_form
          outer_body = mkBody (oneStm scan_stm) (varsRes (patNames inner_pat))
          outer_lam = Lambda outer_ps outer_body (map (`setOuterSize` w) op_ts)
          -- Swap the two outermost dimensions back.  The rank is taken from
          -- the scan operator's return type, not the inner map's.
          swap_outer = case op_ts of
            [] -> []
            t : _ -> [1, 0] ++ [2 .. arrayRank t]
          outer_inputs = ne_inputs ++ map (SOAC.transposeInput 0 1) arrs

      pure
        ( SOAC.Screma inner_w (ScremaForm [] [] outer_lam) outer_inputs,
          ots SOAC.|> SOAC.Rearrange inner_cs swap_outer
        )
iswim _ _ _ = fail "ISWIM does not apply."

-- | Replace a lambda parameter's type with its row type, i.e. drop the
-- outermost dimension.
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim p = p {paramDec = rowType (paramType p)}

-- | Set the outermost dimension of a lambda parameter's type to the
-- given size.
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w p = p {paramDec = setOuterSize (paramType p) w}

-- | Set the outermost dimension of every type in a pattern to the given
-- size.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w pat = (`setOuterSize` w) <$> pat

-- Now for fiddling with transpositions...

-- | Mark each input by whether its array name is in @interesting@.
-- Then hand the marked inputs to 'commonTransforms''.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting = commonTransforms' . map tag
  where
    tag inp = (SOAC.inputArray inp `elem` interesting, inp)

-- | Repeatedly extract a transform shared by all flagged inputs.
--
-- Unflagged inputs are carried along untouched.  Every flagged input
-- must yield a transform via 'inputToOutput', and all those transforms
-- must be equal.  When that holds and at least one flagged input exists:
--
-- * each flagged input is replaced by the input 'inputToOutput' returned;
-- * the shared transform is prepended to the result of recursing on the
--   new inputs.
--
-- Otherwise the result is 'SOAC.noTransforms' with the inputs unchanged.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps
  | Just (Just shared, rev_inps) <- foldM step (Nothing, []) inps =
      -- The fold conses, so restore the original order before recursing.
      first (shared SOAC.<|) $ commonTransforms' (reverse rev_inps)
  | otherwise = (SOAC.noTransforms, map snd inps)
  where
    step (seen, acc) (True, inp) = do
      (ot, inp') <- inputToOutput inp
      case seen of
        Just prev | prev /= ot -> Nothing
        _ -> Just (Just ot, (True, inp') : acc)
    step (seen, acc) unflagged = Just (seen, unflagged : acc)

-- | Depth of a map nest, used by 'pullRearrange' and 'pushRearrange' as
-- the upper bound on 'rearrangeReach'.
--
-- The value is one plus the smaller of two quantities:
--
-- * the number of nesting levels;
-- * the smallest rank among the return types of the outermost nesting
--   (or of the lambda when there are no nestings).
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  min resDims (length levels) + 1
  where
    resDims = minDim $ case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    -- Smallest rank among the types, or 0 when there are none.  The
    -- second equation only sees non-empty lists, so 'minimum' is total
    -- here.  This avoids the lazy 'foldl' of the hand-rolled version.
    minDim [] = 0
    minDim ts = minimum (map arrayRank ts)

-- | Pull the first output transform, which must be a 'SOAC.Rearrange',
-- into the inputs of a SOAC.
--
-- The SOAC must be convertible to a map nest, and the permutation's
-- 'rearrangeReach' must not exceed 'mapDepth' of that nest.  When both
-- hold:
--
-- * every input gets the permutation, extended with identity up to the
--   input's rank;
-- * the nest's return types are permuted;
-- * the remaining transforms are returned.
--
-- Fails otherwise.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< rest <- pure $ SOAC.viewf ots
  if rearrangeReach perm > mapDepth nest
    then fail "Cannot pull transpose"
    else do
      let -- Pad the permutation with the identity to cover every
          -- dimension of the given input.
          widen inp =
            let r = SOAC.inputRank inp
             in take r perm ++ [length perm .. r - 1]
          rearranged =
            [ SOAC.addTransform (SOAC.Rearrange cs (widen inp)) inp
              | inp <- MapNest.inputs nest
            ]
      soac' <-
        MapNest.toSOAC $
          MapNest.setInputs rearranged (rearrangeReturnTypes nest perm)
      pure (soac', rest)

-- | Push a rearrangement out of the inputs of a SOAC and into its output
-- transforms.
--
-- The SOAC must be convertible to a map nest.  'fixupInputs' picks a
-- permutation from an input whose array is in @inpIds@ and applies it to
-- all inputs.  When the permutation's 'rearrangeReach' is within
-- 'mapDepth':
--
-- * the nest's return types are permuted;
-- * the inverse permutation is prepended to the transforms.
--
-- Fails otherwise.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe . fixupInputs inpIds $ MapNest.inputs nest
  if rearrangeReach perm > mapDepth nest
    then fail "Cannot push transpose"
    else do
      soac' <-
        MapNest.toSOAC . MapNest.setInputs inputs' $
          rearrangeReturnTypes nest perm
      let undo = SOAC.Rearrange mempty (rearrangeInverse perm)
      pure (soac', undo SOAC.<| ots)

-- | Permute the return types recorded at each nesting level of a map
-- nest.  Per the original note, this actually also rearranges indices.
--
-- The width, the innermost lambda and the inputs are left unchanged.
-- Nesting level @i@ (counting from 0 at the outermost) receives the
-- permuted nest types with @i + 1@ outer dimensions removed.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    -- The permutation may be deeper than the rank of a type, but it must
    -- be the identity beyond that rank.  Callers of rearrangeReturnTypes
    -- are responsible for checking this invariant.
    permute t = rearrangeType (take (arrayRank t) perm) t
    permuted = map permute (MapNest.typeOf nest)
    levelTypes = drop 1 (iterate (map rowType) permuted)
    nestings' =
      zipWith
        (\n ts -> n {MapNest.nestingReturnType = ts})
        nestings
        levelTypes

-- | Choose a permutation to push out of a list of inputs, and apply it
-- to every input.
--
-- The permutation comes from the first input that satisfies both of:
--
-- * its array name is in @inpIds@;
-- * its last transform (as seen through 'SOAC.viewl') is a
--   'SOAC.Rearrange'.
--
-- Each input then gets an additional 'SOAC.Rearrange' holding the
-- permutation truncated to that input's rank.  Returns 'Nothing' when
-- no such input exists, or when some input's rank is below the
-- permutation's 'rearrangeReach'.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps = do
  perm <-
    listToMaybe
      [p | inp <- inps, isExposed inp, Just p <- [lastRearrange inp]]
  inps' <- traverse (extend (rearrangeReach perm) perm) inps
  pure (perm, inps')
  where
    isExposed inp = SOAC.inputArray inp `elem` inpIds

    lastRearrange inp
      | SOAC.Input ts _ _ <- inp,
        _ SOAC.:> SOAC.Rearrange _ p <- SOAC.viewl ts =
          Just p
      | otherwise = Nothing

    extend d p inp
      | r < d = Nothing
      | otherwise =
          Just $ SOAC.addTransform (SOAC.Rearrange mempty (take r p)) inp
      where
        r = SOAC.inputRank inp

-- Try to pull a leading reshape output transform of a map SOAC onto
-- its inputs.
--
-- This only applies when:
--
--   * the SOAC is a screma that 'Futhark.isMapSOAC' recognises as a
--     map,
--   * the first output transform is a 'SOAC.Reshape', and
--   * every return type of the map lambda is primitive.
--
-- In that case the single map is replaced by a nest of maps with one
-- level per dimension of the reshape target, and every input gets the
-- reshape applied instead ('SOAC.Reshape' for rank-1 inputs,
-- 'SOAC.ReshapeOuter' for higher-rank ones).  The remaining output
-- transforms are returned alongside the new SOAC.  Anything else
-- fails with "Cannot pull reshape".
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape (SOAC.Screma _ form inps) ots
  | Just maplam <- Futhark.isMapSOAC form,
    SOAC.Reshape cs k shape SOAC.:< ots' <- SOAC.viewf ots,
    all primType (lambdaReturnType maplam) = do
      let dims = shapeDims shape

          -- The innermost map runs over the last target dimension (or a
          -- width of zero if the target shape has no dimensions).
          innerWidth = fromMaybe (intConst Int64 0) (listToMaybe (reverse dims))

          reshapeInput inp
            | arrayRank (SOAC.inputType inp) == 1 =
                SOAC.addTransform (SOAC.Reshape cs k shape) inp
            | otherwise =
                SOAC.addTransform (SOAC.ReshapeOuter cs k shape) inp

          newInputs = map reshapeInput inps
          newInputTypes = map SOAC.inputType newInputs

          -- One (width, dimensions-nested-inside-it) pair per target
          -- dimension except the innermost, ordered innermost-first so
          -- that the fold below builds the nest from the inside out.
          levels =
            zip (drop 1 (reverse dims)) $
              drop 1 (reverse (drop 1 (tails dims)))

          -- Wrap an inner map in one more map over @w@.  The lambda
          -- parameters have the input types with the outer dimensions
          -- stripped, and the return types are the original ones
          -- prefixed with the dimensions nested inside @w@.
          wrapLevel ::
            ([SOAC.Input] -> SOAC) ->
            (SubExp, [SubExp]) ->
            TryFusion ([SOAC.Input] -> SOAC)
          wrapLevel inner (w, innerDims) = do
            params <-
              forM newInputTypes $
                newParam "pullReshape_param"
                  . stripArray (length dims - length innerDims)
            body <-
              runBodyBuilder $
                eBody [SOAC.toExp (inner (map (SOAC.identInput . paramIdent) params))]
            let wrapped =
                  Lambda
                    { lambdaParams = params,
                      lambdaReturnType =
                        [ arrayOf t (Shape innerDims) NoUniqueness
                          | t <- lambdaReturnType maplam
                        ],
                      lambdaBody = body
                    }
            pure $ SOAC.Screma w $ Futhark.mapSOAC wrapped

      nest <- foldM wrapLevel (SOAC.Screma innerWidth (Futhark.mapSOAC maplam)) levels
      pure (nest newInputs, ots')
pullReshape _ _ = fail "Cannot pull reshape"

-- Tie it all together in exposeInputs (for making inputs to a
-- consumer available) and pullOutputTransforms (for moving
-- output-transforms of a producer to its inputs instead).

-- | Make the arrays named by the given 'VName's available as inputs of
-- the kernel that carry no transforms of their own, if possible.
--
-- Three candidate kernels are tried in order, keeping the first that
-- works:
--
--   1. the kernel after 'pushRearrange' has moved rearranges out of
--      its output transform;
--   2. the kernel after 'pullRearrange', accepted only when no output
--      transform is left over;
--   3. the kernel unchanged.
--
-- For each candidate, 'commonTransforms' splits off the transforms
-- shared by the requested inputs.  The candidate is accepted only if
-- every remaining input is either untransformed or not one of the
-- requested arrays.  The result is the updated kernel paired with the
-- split-off transforms.
exposeInputs ::
  [VName] ->
  FusedSOAC ->
  TryFusion (FusedSOAC, SOAC.ArrayTransforms)
exposeInputs inpIds ker =
  (withPushed >>= expose)
    <|> (withPulled >>= expose)
    <|> expose ker
  where
    withPushed :: TryFusion FusedSOAC
    withPushed = do
      (newSoac, newOts) <- pushRearrange inpIds (fsSOAC ker) (fsOutputTransform ker)
      pure (ker {fsSOAC = newSoac, fsOutputTransform = newOts})

    withPulled :: TryFusion FusedSOAC
    withPulled = do
      (newSoac, newOts) <- pullRearrange (fsSOAC ker) (fsOutputTransform ker)
      -- Pulling only helps if it consumed the entire output transform.
      if SOAC.nullTransforms newOts
        then pure (ker {fsSOAC = newSoac, fsOutputTransform = SOAC.noTransforms})
        else fail "pullRearrange was not enough"

    expose :: FusedSOAC -> TryFusion (FusedSOAC, SOAC.ArrayTransforms)
    expose candidate
      | all isExposed remaining =
          pure (candidate {fsSOAC = SOAC.setInputs remaining (fsSOAC candidate)}, shared)
      | otherwise = fail "Cannot expose"
      where
        (shared, remaining) = commonTransforms inpIds (inputs candidate)

    -- An input counts as exposed if it is untransformed, or if it is
    -- not one of the arrays we were asked to expose.
    isExposed :: SOAC.Input -> Bool
    isExposed inp@(SOAC.Input ts _ _) =
      SOAC.nullTransforms ts || SOAC.inputArray inp `notElem` inpIds

-- | The functions 'pullOutputTransforms' tries for moving an output
-- transform onto a SOAC's inputs, in the order they are attempted.
outputTransformPullers ::
  [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers =
  [ pullRearrange,
    pullReshape
  ]

-- | Move as many output transforms of a SOAC onto its inputs as
-- possible.
--
-- The pullers in 'outputTransformPullers' are tried in order on the
-- original SOAC and transforms.
--
--   * If a puller succeeds and leaves no transforms behind, the new
--     SOAC is returned with 'SOAC.noTransforms'.
--   * If it succeeds but transforms remain, pulling is repeated on the
--     result.  If that further pulling fails, the partial result from
--     this puller is kept.
--   * If a puller fails, the next one is tried.
--
-- Fails with "Cannot pull anything" when no puller succeeds.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms soac ots = go outputTransformPullers
  where
    go ::
      [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)] ->
      TryFusion (SOAC, SOAC.ArrayTransforms)
    go [] = fail "Cannot pull anything"
    go (puller : rest) = usePuller puller <|> go rest

    usePuller ::
      (SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)) ->
      TryFusion (SOAC, SOAC.ArrayTransforms)
    usePuller puller = do
      (pulledSoac, leftover) <- puller soac ots
      if SOAC.nullTransforms leftover
        then pure (pulledSoac, SOAC.noTransforms)
        else
          -- Keep pulling, but fall back to what we have if that fails.
          pullOutputTransforms pulledSoac leftover <|> pure (pulledSoac, leftover)