{-# LANGUAGE TypeFamilies #-}

-- | Facilities for fusing two SOACs.
--
-- When the fusion algorithm decides that it's worth fusing two SOAC
-- statements, this is the module that tries to see if that's
-- possible.  May involve massaging either producer or consumer in
-- various ways.
module Futhark.Optimise.Fusion.TryFusion
  ( FusedSOAC (..),
    attemptFusion,
  )
where

import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.MapNest qualified as MapNest
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)

-- | The monad in which a single fusion attempt runs.  It reads the
-- current type scope ('ReaderT' over @'Scope' 'SOACS'@), threads a
-- fresh-name source ('StateT' over 'VNameSource'), and may fail
-- (the underlying 'Maybe').  Failure is raised with 'empty', 'guard'
-- or 'fail', and '<|>' lets callers fall back to another strategy.
newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Run a 'TryFusion' computation in the given scope.
--
-- On success, returns @'Just' x@ and commits the name source to the
-- state the computation ended in.  On failure, returns 'Nothing' and
-- leaves the caller's name source exactly as it was, so the failed
-- attempt consumes no names.
tryFusion ::
  MonadFreshNames m =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
  case runStateT (runReaderT m types) src of
    Just (x, src') -> (Just x, src')
    -- Roll back: hand back the original source untouched.
    Nothing -> (Nothing, src)

-- | Lift a 'Maybe' into 'TryFusion'.  'Nothing' becomes a failed
-- attempt (via 'fail' with the message @"Nothing"@), and @'Just' x@
-- succeeds with @x@.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe Nothing = fail "Nothing"
liftMaybe (Just x) = pure x

-- | Shorthand for the 'SOAC.SOAC' view used throughout this module,
-- specialised to the 'SOACS' representation.
type SOAC = SOAC.SOAC SOACS

-- | Shorthand for 'MapNest.MapNest' specialised to 'SOACS'.
type MapNest = MapNest.MapNest SOACS

-- | Detach the first array transform of an input, as seen from the
-- front of its transform sequence by 'SOAC.viewf'.
--
-- Returns that transform together with the same input (same array
-- and type) carrying only the remaining transforms.  Returns
-- 'Nothing' if the input has no transforms at all.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing

-- | A fused SOAC contains a bit of extra information.
data FusedSOAC = FusedSOAC
  { -- | The actual SOAC.
    fsSOAC :: SOAC,
    -- | A transformation to be applied to *all* results of the SOAC.
    fsOutputTransform :: SOAC.ArrayTransforms,
    -- | The outputs of the SOAC (i.e. the names in the pattern that
    -- the result of this SOAC should be bound to).
    fsOutNames :: [VName]
  }
  deriving (Show)

-- | The inputs of the SOAC wrapped by a 'FusedSOAC'.
inputs :: FusedSOAC -> [SOAC.Input]
inputs = SOAC.inputs . fsSOAC

-- | Replace the inputs of the SOAC wrapped by a 'FusedSOAC'.  The
-- output transform and output names are left unchanged.
setInputs :: [SOAC.Input] -> FusedSOAC -> FusedSOAC
setInputs inps ker = ker {fsSOAC = inps `SOAC.setInputs` fsSOAC ker}

-- | Strategy: simplify the producer with 'optimizeSOAC', then retry
-- the fusion rules.
--
-- 'optimizeSOAC' (defined elsewhere) yields a new producer and a
-- sequence of array transforms.  Those transforms are prepended
-- ('SOAC.addInitialTransforms') to every consumer input whose array
-- is one of the producer outputs @outVars@.  The consumer's recorded
-- types for those inputs are then refreshed from the new producer's
-- result types before 'applyFusionRules' is retried.
--
-- NOTE(review): this recurses through 'applyFusionRules'.  It
-- presumably terminates because 'optimizeSOAC' fails once there is
-- nothing left to optimise; confirm against its definition.
tryOptimizeSOAC ::
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeSOAC unfus_nms outVars soac ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = map (addInitialTransformIfRelevant ots) (inputs ker) `setInputs` ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules unfus_nms outVars soac' ker''
  where
    -- Only consumer inputs that read a producer output see the
    -- transforms; all other inputs are left alone.
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
          SOAC.addInitialTransforms ots inp
      | otherwise =
          inp

-- | Strategy: simplify the consumer with 'optimizeKernel' (told which
-- arrays are the producer outputs), then retry the fusion rules with
-- the unchanged producer.
tryOptimizeKernel ::
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeKernel unfus_nms outVars soac ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules unfus_nms outVars soac ker'

-- | Strategy: expose the consumer's uses of the producer outputs with
-- 'exposeInputs'.
--
-- * If no transforms are returned, fuse the producer directly into
--   the exposed consumer with 'fuseSOACwithKer'.
-- * Otherwise, proceed only when no producer outputs need to be kept
--   (@unfus_nms@ is empty).  The transforms are pushed into the
--   producer with 'pullOutputTransforms'.  If that absorbs all of
--   them, the consumer's input types are refreshed from the new
--   producer and the fusion rules are retried.  If any transforms
--   remain, the attempt fails.
tryExposeInputs ::
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryExposeInputs unfus_nms outVars soac ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer unfus_nms outVars soac ker'
    else do
      guard $ unfus_nms == mempty
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules unfus_nms outVars soac' ker''
        else fail "tryExposeInputs could not pull SOAC transforms"

-- | For each input of the fused SOAC whose array name matches one of
-- the given identifiers, overwrite the input's recorded type with that
-- identifier's type.  Transforms and non-matching inputs are kept
-- unchanged.
fixInputTypes :: [Ident] -> FusedSOAC -> FusedSOAC
fixInputTypes outIdents ker =
  ker {fsSOAC = fixInputTypes' $ fsSOAC ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
          SOAC.Input ts v $ identType v'
    fixInputType inp = inp

-- | Try each fusion strategy in order and return the first success:
--
--   1. simplify the producer ('tryOptimizeSOAC'),
--   2. simplify the consumer ('tryOptimizeKernel'),
--   3. fuse directly ('fuseSOACwithKer'),
--   4. expose consumer inputs ('tryExposeInputs').
--
-- The choice is left-biased via '<|>'.
applyFusionRules ::
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
applyFusionRules unfus_nms outVars soac ker =
  tryOptimizeSOAC unfus_nms outVars soac ker
    <|> tryOptimizeKernel unfus_nms outVars soac ker
    <|> fuseSOACwithKer unfus_nms outVars soac ker
    <|> tryExposeInputs unfus_nms outVars soac ker

-- | Attempt fusing the producer into the consumer.
--
-- Returns 'Nothing' if no fusion rule applies.  In that case the
-- name source is left untouched (see 'tryFusion').
attemptFusion ::
  (HasScope SOACS m, MonadFreshNames m) =>
  -- | Outputs of the producer that should still be output by the
  -- fusion result (corresponding to "diagonal fusion").
  Names ->
  -- | The outputs of the producer SOAC.
  [VName] ->
  -- | The producer SOAC.
  SOAC ->
  -- | The consumer.
  FusedSOAC ->
  m (Maybe FusedSOAC)
attemptFusion unfus_nms outVars soac ker = do
  scope <- askScope
  tryFusion (applyFusionRules unfus_nms outVars soac ker) scope

-- | Check that the consumer does not use any scan or reduce results.
--
-- The first component of the pair holds the producer's scan and
-- reduce outputs.  The second (its map outputs) is ignored.  Only
-- consumer inputs accepted by 'SOAC.isVarishInput' are inspected.
-- NOTE(review): transformed uses of producer outputs appear to be
-- rejected earlier, by the @inputOrUnfus@ guard in 'fuseSOACwithKer';
-- confirm before relying on this check alone.
scremaFusionOK :: ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (nonmap_outs, _map_outs) ker =
  all (`notElem` nonmap_outs) $ mapMaybe SOAC.isVarishInput (inputs ker)

-- | Check that the consumer uses all the outputs of the producer
-- unmodified.  Every name in @outVars@ must appear among the
-- consumer's inputs accepted by 'SOAC.isVarishInput'.
mapWriteFusionOK :: [VName] -> FusedSOAC -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)

-- | The brain of this module: Fusing a SOAC with a Kernel.
fuseSOACwithKer ::
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
fuseSOACwithKer :: Names -> [VName] -> SOAC SOACS -> FusedSOAC -> TryFusion FusedSOAC
fuseSOACwithKer Names
unfus_set [VName]
outVars SOAC SOACS
soac_p FusedSOAC
ker = do
  -- We are fusing soac_p into soac_c, i.e, the output of soac_p is going
  -- into soac_c.
  let soac_c :: SOAC SOACS
soac_c = FusedSOAC -> SOAC SOACS
fsSOAC FusedSOAC
ker
      inp_p_arr :: [Input]
inp_p_arr = forall {k} (rep :: k). SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_p
      horizFuse :: Bool
horizFuse = Names
unfus_set forall a. Eq a => a -> a -> Bool
/= forall a. Monoid a => a
mempty
      inp_c_arr :: [Input]
inp_c_arr = forall {k} (rep :: k). SOAC rep -> [Input]
SOAC.inputs SOAC SOACS
soac_c
      lam_p :: Lambda SOACS
lam_p = forall {k} (rep :: k). SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_p
      lam_c :: Lambda SOACS
lam_c = forall {k} (rep :: k). SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
soac_c
      w :: SubExp
w = forall {k} (rep :: k). SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p
      returned_outvars :: [VName]
returned_outvars = forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
      success :: [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_outnms SOAC SOACS
res_soac = do
        -- Avoid name duplication, because the producer lambda is not
        -- removed from the program until much later.
        Lambda SOACS
uniq_lam <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SOAC rep -> Lambda rep
SOAC.lambda SOAC SOACS
res_soac
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
          FusedSOAC
ker
            { fsSOAC :: SOAC SOACS
fsSOAC = Lambda SOACS
uniq_lam forall {k} (rep :: k). Lambda rep -> SOAC rep -> SOAC rep
`SOAC.setLambda` SOAC SOACS
res_soac,
              fsOutNames :: [VName]
fsOutNames = [VName]
res_outnms
            }

  -- Can only fuse SOACs with same width.
  forall (f :: * -> *). Alternative f => Bool -> f ()
guard forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p forall a. Eq a => a -> a -> Bool
== forall {k} (rep :: k). SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c

  -- If we are getting rid of a producer output, then it must be used
  -- without any transformation.
  let bare_inputs :: [VName]
bare_inputs = forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Input -> Maybe VName
SOAC.isVarishInput (FusedSOAC -> [Input]
inputs FusedSOAC
ker)
      ker_inputs :: [VName]
ker_inputs = forall a b. (a -> b) -> [a] -> [b]
map Input -> VName
SOAC.inputArray (FusedSOAC -> [Input]
inputs FusedSOAC
ker)
      inputOrUnfus :: VName -> Bool
inputOrUnfus VName
v = VName
v forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
bare_inputs Bool -> Bool -> Bool
|| VName
v forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
ker_inputs

  forall (f :: * -> *). Alternative f => Bool -> f ()
guard forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all VName -> Bool
inputOrUnfus [VName]
outVars

  [(VName, Ident)]
outPairs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall u. TypeBase Shape u -> TypeBase Shape u
rowType forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p) forall a b. (a -> b) -> a -> b
$ \(VName
outVar, TypeBase Shape NoUniqueness
t) -> do
    VName
outVar' <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
outVar forall a. [a] -> [a] -> [a]
++ [Char]
"_elem"
    forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
outVar, VName -> TypeBase Shape NoUniqueness -> Ident
Ident VName
outVar' TypeBase Shape NoUniqueness
t)

  let mapLikeFusionCheck :: ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck =
        let (Lambda SOACS
res_lam, [Input]
new_inp) = forall {k} (rep :: k).
Buildable rep =>
Names
-> Lambda rep
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [Input]
-> (Lambda rep, [Input])
fuseMaps Names
unfus_set Lambda SOACS
lam_p [Input]
inp_p_arr [(VName, Ident)]
outPairs Lambda SOACS
lam_c [Input]
inp_c_arr
            ([VName]
extra_nms, [TypeBase Shape NoUniqueness]
extra_rtps) =
              forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
                forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
unfus_set) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) forall a b. (a -> b) -> a -> b
$
                  forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars forall a b. (a -> b) -> a -> b
$
                    forall a b. (a -> b) -> [a] -> [b]
map (forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray Int
1) forall a b. (a -> b) -> a -> b
$
                      forall {k} (rep :: k). SOAC rep -> [TypeBase Shape NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p
            res_lam' :: Lambda SOACS
res_lam' = Lambda SOACS
res_lam {lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = forall {k} (rep :: k). Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
res_lam forall a. [a] -> [a] -> [a]
++ [TypeBase Shape NoUniqueness]
extra_rtps}
         in ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp)

  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Bool
horizFuse Bool -> Bool -> Bool
&& Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms forall a b. (a -> b) -> a -> b
$ FusedSOAC -> ArrayTransforms
fsOutputTransform FusedSOAC
ker)) forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Horizontal fusion is invalid in the presence of output transforms."

  case (SOAC SOACS
soac_c, SOAC SOACS
soac_p) of
    (SOAC SOACS, SOAC SOACS)
_ | forall {k} (rep :: k). SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_p forall a. Eq a => a -> a -> Bool
/= forall {k} (rep :: k). SOAC rep -> SubExp
SOAC.width SOAC SOACS
soac_c -> forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC widths must match."
    ( SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_c [Reduce SOACS]
reds_c Lambda SOACS
_) [Input]
_,
      SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_p [Reduce SOACS]
reds_p Lambda SOACS
_) [Input]
_
      )
        | ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [Scan rep] -> Int
Futhark.scanResults [Scan SOACS]
scans_p forall a. Num a => a -> a -> a
+ forall {k} (rep :: k). [Reduce rep] -> Int
Futhark.redResults [Reduce SOACS]
reds_p) [VName]
outVars) FusedSOAC
ker -> do
            let red_nes_p :: [SubExp]
red_nes_p = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall {k} (rep :: k). Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_p
                red_nes_c :: [SubExp]
red_nes_c = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall {k} (rep :: k). Reduce rep -> [SubExp]
redNeutral [Reduce SOACS]
reds_c
                scan_nes_p :: [SubExp]
scan_nes_p = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall {k} (rep :: k). Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_p
                scan_nes_c :: [SubExp]
scan_nes_c = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap forall {k} (rep :: k). Scan rep -> [SubExp]
scanNeutral [Scan SOACS]
scans_c
                (Lambda SOACS
res_lam', [Input]
new_inp) =
                  forall {k} (rep :: k).
Buildable rep =>
Names
-> [VName]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda rep
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda rep, [Input])
fuseRedomap
                    Names
unfus_set
                    [VName]
outVars
                    Lambda SOACS
lam_p
                    [SubExp]
scan_nes_p
                    [SubExp]
red_nes_p
                    [Input]
inp_p_arr
                    [(VName, Ident)]
outPairs
                    Lambda SOACS
lam_c
                    [SubExp]
scan_nes_c
                    [SubExp]
red_nes_c
                    [Input]
inp_c_arr
                ([VName]
soac_p_scanout, [VName]
soac_p_redout, [VName]
_soac_p_mapout) =
                  forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_p) (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_p) [VName]
outVars
                ([VName]
soac_c_scanout, [VName]
soac_c_redout, [VName]
soac_c_mapout) =
                  forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_c) (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_c) forall a b. (a -> b) -> a -> b
$ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker
                unfus_arrs :: [VName]
unfus_arrs = [VName]
returned_outvars forall a. Eq a => [a] -> [a] -> [a]
\\ ([VName]
soac_p_scanout forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout)
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success
              ( [VName]
soac_p_scanout
                  forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_scanout
                  forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout
                  forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_redout
                  forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_mapout
                  forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs
              )
              forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma
                SubExp
w
                (forall {k} (rep :: k).
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm ([Scan SOACS]
scans_p forall a. [a] -> [a] -> [a]
++ [Scan SOACS]
scans_c) ([Reduce SOACS]
reds_p forall a. [a] -> [a] -> [a]
++ [Reduce SOACS]
reds_c) Lambda SOACS
res_lam')
                [Input]
new_inp

    ------------------
    -- Scatter fusion --
    ------------------

    -- Map-Scatter fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Scatter is not writing to any array used in
    -- the Map.
    ( SOAC.Scatter SubExp
_len Lambda SOACS
_lam [Input]
_ivs [(Shape, Int, VName)]
dests,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_
      )
        | forall a. Maybe a -> Bool
isJust forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the scatter, i.e., not used elsewhere.
          forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
          -- 2. all arrays produced by the map are input to the scatter.
          [VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
            let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k).
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
res_lam' [Input]
new_inp [(Shape, Int, VName)]
dests

    -- Map-Hist fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Hist is not writing to any array used in
    -- the Map.
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops Lambda SOACS
_ [Input]
_,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_
      )
        | forall a. Maybe a -> Bool
isJust forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the hist, i.e., not used elsewhere.
          forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Names
unfus_set) [VName]
outVars,
          -- 2. all arrays produced by the map are input to the scatter.
          [VName] -> FusedSOAC -> Bool
mapWriteFusionOK [VName]
outVars FusedSOAC
ker -> do
            let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k).
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
res_lam' [Input]
new_inp

    -- Hist-Hist fusion
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops_c Lambda SOACS
_ [Input]
_,
      SOAC.Hist SubExp
_ [HistOp SOACS]
ops_p Lambda SOACS
_ [Input]
_
      )
        | Bool
horizFuse -> do
            let p_num_buckets :: Int
p_num_buckets = forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_p
                c_num_buckets :: Int
c_num_buckets = forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_c
                (Body SOACS
body_p, Body SOACS
body_c) = (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
                body' :: Body SOACS
body' =
                  Body
                    { bodyDec :: BodyDec SOACS
bodyDec = forall {k} (rep :: k). Body rep -> BodyDec rep
bodyDec Body SOACS
body_p, -- body_p and body_c have the same decorations
                      bodyStms :: Stms SOACS
bodyStms = forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body SOACS
body_p forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body SOACS
body_c,
                      bodyResult :: Result
bodyResult =
                        forall a. Int -> [a] -> [a]
take Int
c_num_buckets (forall {k} (rep :: k). Body rep -> Result
bodyResult Body SOACS
body_c)
                          forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
take Int
p_num_buckets (forall {k} (rep :: k). Body rep -> Result
bodyResult Body SOACS
body_p)
                          forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (forall {k} (rep :: k). Body rep -> Result
bodyResult Body SOACS
body_c)
                          forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (forall {k} (rep :: k). Body rep -> Result
bodyResult Body SOACS
body_p)
                    }
                lam' :: Lambda SOACS
lam' =
                  Lambda
                    { lambdaParams :: [LParam SOACS]
lambdaParams = forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c forall a. [a] -> [a] -> [a]
++ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
                      lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
                      lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType =
                        forall a. Int -> a -> [a]
replicate (Int
c_num_buckets forall a. Num a => a -> a -> a
+ Int
p_num_buckets) (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
                          forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (forall {k} (rep :: k). Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c)
                          forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (forall {k} (rep :: k). Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
                    }
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k).
SubExp -> [HistOp rep] -> Lambda rep -> [Input] -> SOAC rep
SOAC.Hist SubExp
w ([HistOp SOACS]
ops_c forall a. Semigroup a => a -> a -> a
<> [HistOp SOACS]
ops_p) Lambda SOACS
lam' ([Input]
inp_c_arr forall a. Semigroup a => a -> a -> a
<> [Input]
inp_p_arr)

    -- Scatter-write fusion.
    ( SOAC.Scatter SubExp
_len_c Lambda SOACS
_lam_c [Input]
ivs_c [(Shape, Int, VName)]
as_c,
      SOAC.Scatter SubExp
_len_p Lambda SOACS
_lam_p [Input]
ivs_p [(Shape, Int, VName)]
as_p
      )
        | Bool
horizFuse -> do
            let zipW :: [(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, array)]
as_xs [a]
xs [(Shape, Int, array)]
as_ys [a]
ys = [a]
xs_indices forall a. [a] -> [a] -> [a]
++ [a]
ys_indices forall a. [a] -> [a] -> [a]
++ [a]
xs_vals forall a. [a] -> [a] -> [a]
++ [a]
ys_vals
                  where
                    ([a]
xs_indices, [a]
xs_vals) = forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_xs [a]
xs
                    ([a]
ys_indices, [a]
ys_vals) = forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_ys [a]
ys
            let (Body SOACS
body_p, Body SOACS
body_c) = (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_p, forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_c)
            let body' :: Body SOACS
body' =
                  Body
                    { bodyDec :: BodyDec SOACS
bodyDec = forall {k} (rep :: k). Body rep -> BodyDec rep
bodyDec Body SOACS
body_p, -- body_p and body_c have the same decorations
                      bodyStms :: Stms SOACS
bodyStms = forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body SOACS
body_p forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body SOACS
body_c,
                      bodyResult :: Result
bodyResult = forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (forall {k} (rep :: k). Body rep -> Result
bodyResult Body SOACS
body_c) [(Shape, Int, VName)]
as_p (forall {k} (rep :: k). Body rep -> Result
bodyResult Body SOACS
body_p)
                    }
            let lam' :: Lambda SOACS
lam' =
                  Lambda
                    { lambdaParams :: [LParam SOACS]
lambdaParams = forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_c forall a. [a] -> [a] -> [a]
++ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_p,
                      lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body',
                      lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = forall {array} {a} {array}.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (forall {k} (rep :: k). Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c) [(Shape, Int, VName)]
as_p (forall {k} (rep :: k). Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
                    }
            [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k).
SubExp
-> Lambda rep -> [Input] -> [(Shape, Int, VName)] -> SOAC rep
SOAC.Scatter SubExp
w Lambda SOACS
lam' ([Input]
ivs_c forall a. [a] -> [a] -> [a]
++ [Input]
ivs_p) ([(Shape, Int, VName)]
as_c forall a. [a] -> [a] -> [a]
++ [(Shape, Int, VName)]
as_p)
    (SOAC.Scatter {}, SOAC SOACS
_) ->
      forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a write with anything else than a write or a map"
    (SOAC SOACS
_, SOAC.Scatter {}) ->
      forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse a write with anything else than a write or a map"
    ----------------------------
    -- Stream-Stream Fusions: --
    ----------------------------
    (SOAC.Stream {}, SOAC.Stream {}) -> do
      -- fuse two SEQUENTIAL streams
      ([VName]
res_nms, SOAC SOACS
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC SOACS
-> SOAC SOACS
-> TryFusion ([VName], SOAC SOACS)
fuseStreamHelper (FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC SOACS
soac_c SOAC SOACS
soac_p
      [VName] -> SOAC SOACS -> TryFusion FusedSOAC
success [VName]
res_nms SOAC SOACS
res_stream
    -------------------------------------------------------------------
    --- If one is a stream, translate the other to a stream as well.---
    --- This does not get in trouble (infinite computation) because ---
    ---   scan's translation to Stream introduces a hindrance to    ---
    ---   (horizontal fusion), hence repeated application is for the---
    ---   moment impossible. However, if with a dependence-graph rep---
    ---   we could run in an infinite recursion, i.e., repeatedly   ---
    ---   fusing map o scan into an infinity of Stream levels!      ---
    -------------------------------------------------------------------
    (SOAC.Stream {}, SOAC SOACS
_) -> do
      -- If this rule is matched then soac_p is NOT a stream.
      -- To fuse a stream kernel, we transform soac_p to a stream, which
      -- borrows the sequential/parallel property of the soac_c Stream,
      -- and recursively perform stream-stream fusion.
      (SOAC SOACS
soac_p', [Ident]
newacc_ids) <- forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
      Names -> [VName] -> SOAC SOACS -> FusedSOAC -> TryFusion FusedSOAC
fuseSOACwithKer
        ([VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
        (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
        SOAC SOACS
soac_p'
        FusedSOAC
ker
    (SOAC SOACS
_, SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_) | Just ([Scan SOACS], Lambda SOACS)
_ <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm SOACS
form -> do
      -- A Scan soac can be currently only fused as a (sequential) stream,
      -- hence it is first translated to a (sequential) Stream and then
      -- fusion with a kernel is attempted.
      (SOAC SOACS
soac_p', [Ident]
newacc_ids) <- forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
      if SOAC SOACS
soac_p' forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_p
        then
          Names -> [VName] -> SOAC SOACS -> FusedSOAC -> TryFusion FusedSOAC
fuseSOACwithKer
            ([VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
            (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids forall a. [a] -> [a] -> [a]
++ [VName]
outVars)
            SOAC SOACS
soac_p'
            FusedSOAC
ker
        else forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."
    (SOAC SOACS
_, SOAC.Stream {}) -> do
      -- If it reached this case then soac_c is NOT a Stream kernel,
      -- hence transform the kernel's soac to a stream and attempt
      -- stream-stream fusion recursivelly.
      -- The newly created stream corresponding to soac_c borrows the
      -- sequential/parallel property of the soac_p stream.
      (SOAC SOACS
soac_c', [Ident]
newacc_ids) <- forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
SOAC.soacToStream SOAC SOACS
soac_c
      if SOAC SOACS
soac_c' forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_c
        then
          Names -> [VName] -> SOAC SOACS -> FusedSOAC -> TryFusion FusedSOAC
fuseSOACwithKer
            ([VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids) forall a. Semigroup a => a -> a -> a
<> Names
unfus_set)
            [VName]
outVars
            SOAC SOACS
soac_p
            forall a b. (a -> b) -> a -> b
$ FusedSOAC
ker {fsSOAC :: SOAC SOACS
fsSOAC = SOAC SOACS
soac_c', fsOutNames :: [VName]
fsOutNames = forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids forall a. [a] -> [a] -> [a]
++ FusedSOAC -> [VName]
fsOutNames FusedSOAC
ker}
        else forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"SOAC could not be turned into stream."

    ---------------------------------
    --- DEFAULT, CANNOT FUSE CASE ---
    ---------------------------------
    (SOAC SOACS, SOAC SOACS)
_ -> forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Cannot fuse"

-- | Fuse a consumer stream (first 'SOAC' argument) with a producer
-- stream (second 'SOAC' argument).
--
-- This is much like redomap-redomap composition via 'fuseRedomap',
-- except that the leading chunk-size parameter of each stream lambda
-- is stripped before composing.  The consumer's chunk parameter is
-- renamed to the producer's, and the producer's chunk parameter is put
-- back in front of the fused lambda.  The fused stream uses the
-- consumer's width.
--
-- Returns the output names of the fused stream together with the
-- stream.  Fails if either argument is not a stream, or if either
-- stream lambda has no parameters.
fuseStreamHelper ::
  -- | Output names of the consumer kernel.
  [VName] ->
  -- | Producer outputs that must remain available after fusion.
  Names ->
  -- | Producer output variables.
  [VName] ->
  -- | Passed unchanged to 'fuseRedomap'.
  [(VName, Ident)] ->
  -- | Consumer stream.
  SOAC ->
  -- | Producer stream.
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ lam1 nes1 inp1_arr)
    | chunk1 : params1 <- lambdaParams lam1,
      chunk2 : _ <- lambdaParams lam2 = do
        let -- Make the consumer refer to the producer's chunk size.
            hmnms = M.fromList [(paramName chunk2, paramName chunk1)]
            lam20 = substituteNames hmnms lam2
            -- Strip the chunk parameters before redomap-style composition.
            lam1' = lam1 {lambdaParams = params1}
            lam2' = lam20 {lambdaParams = drop 1 $ lambdaParams lam20}
            (res_lam', new_inp) =
              fuseRedomap
                unfus_set
                outVars
                lam1'
                []
                nes1
                inp1_arr
                outPairs
                lam2'
                []
                nes2
                inp2_arr
            -- Reinstate a single chunk parameter in front.
            res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
            -- The producer outputs paired with its neutral elements always
            -- escape; its other outputs escape only when in unfus_set.
            unfus_accs = take (length nes1) outVars
            unfus_arrs =
              filter (`notElem` unfus_accs) $
                filter (`nameIn` unfus_set) outVars
        pure
          ( unfus_accs ++ out_kernms ++ unfus_arrs,
            SOAC.Stream w2 res_lam'' (nes1 ++ nes2) new_inp
          )
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"

-- Here follows optimizations and transforms to expose fusability.

-- | Run 'optimizeSOAC' on the SOAC of a fused kernel.  The rewritten
-- SOAC and output transforms are stored back into the kernel.  Fails
-- (via 'optimizeSOAC') when no optimisation applies.
optimizeKernel :: Maybe [VName] -> FusedSOAC -> TryFusion FusedSOAC
optimizeKernel inp ker = do
  (soac, resTrans) <- optimizeSOAC inp (fsSOAC ker) (fsOutputTransform ker)
  pure $ ker {fsSOAC = soac, fsOutputTransform = resTrans}

-- | Try every optimisation in 'optimizations', in order.  The SOAC and
-- its output transforms are threaded from one optimisation to the next.
-- An optimisation that fails is skipped, and its input is passed on
-- unchanged.  Fails with "No optimisation applied" if none of them
-- succeeded.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> pure (soac', os')
  where
    comb (changed, soac', os') f = applied <|> pure (changed, soac', os')
      where
        -- Feed the transforms accumulated so far (os', not the original
        -- os) to the next optimisation, so earlier results are kept.
        applied = do
          (soac'', os'') <- f inp soac' os'
          pure (True, soac'', os'')

-- | A rewrite tried by 'optimizeSOAC'.  It takes an optional list of
-- input names, a SOAC and that SOAC's output transforms.  It either
-- fails or yields a rewritten SOAC with updated output transforms.
type Optimization =
  Maybe [VName] -> SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)

-- | The optimisations tried by 'optimizeSOAC', in order.
optimizations :: [Optimization]
optimizations = [iswim]

-- | ISWIM: interchange a scan with an inner map.
--
-- Applies to a 'SOAC.Screma' that meets all of these conditions:
--
-- * it is a single scan;
-- * 'rwimPossible' recognises its operator as a map;
-- * its neutral elements are all variables.
--
-- The scan becomes an outer map (of the width reported by
-- 'rwimPossible').  The body of that map scans over the inputs with
-- their two outermost dimensions swapped, and the neutral elements
-- become extra map inputs.  A 'SOAC.Rearrange' swapping the two
-- outermost dimensions is appended to the output transforms.  Fails in
-- every other case.
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w form arrs) ots
  | Just [Futhark.Scan scan_fun nes] <- Futhark.isScanSOAC form,
    Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun,
    Just nes_names <- mapM subExpVar nes = do
      let nes_idents = zipWith Ident nes_names $ lambdaReturnType scan_fun
          map_nes = map SOAC.identInput nes_idents
          -- Neutral elements first, then the inputs with their two
          -- outermost dimensions swapped.
          map_arrs' = map_nes ++ map (SOAC.transposeInput 0 1) arrs
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          map_params =
            map removeParamOuterDim scan_acc_params
              ++ map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (`setOuterSize` w) $ lambdaReturnType scan_fun
          -- The inner map's lambda becomes the new scan operator.
          scan_fun' = map_fun
          -- The first (length map_nes) map parameters carry the neutral
          -- elements; the remaining ones are the arrays to scan.
          (nes_params, arrs') =
            splitAt (length map_nes) $ map paramName map_params
          nes' = map Var nes_params

      scan_form <- scanSOAC [Futhark.Scan scan_fun' nes']

      let scan_stm =
            Let (setPatOuterDimTo w map_pat) (defAux ()) $
              Op $
                Futhark.Screma w arrs' scan_form
          map_body = mkBody (oneStm scan_stm) $ varsRes $ patNames map_pat
          map_fun' = Lambda map_params map_body map_rettype
          -- The rank comes from the scan operator's return type, not
          -- from the inner map's.
          perm = case lambdaReturnType scan_fun of
            [] -> []
            t : _ -> 1 : 0 : [2 .. arrayRank t]

      pure
        ( SOAC.Screma map_w (ScremaForm [] [] map_fun') map_arrs',
          ots SOAC.|> SOAC.Rearrange map_cs perm
        )
iswim _ _ _ =
  fail "ISWIM does not apply."

-- | Replace the type of a lambda parameter with its 'rowType', which
-- drops the outermost dimension.
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim param =
  param {paramDec = rowType $ paramType param}

-- | Set the outermost dimension of a lambda parameter's type to the
-- given size.
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w param =
  param {paramDec = paramType param `setOuterSize` w}

-- | Set the outermost dimension of every type in a pattern to the
-- given size.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w = fmap (`setOuterSize` w)

-- Now for fiddling with transpositions...

-- | Split off the transforms that all inputs whose array is in
-- @interesting@ have in common (see 'commonTransforms'').  Returns the
-- shared transforms and the inputs with those transforms removed.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting inps = commonTransforms' $ map tag inps
  where
    -- Flag each input by whether its array is one we care about.
    tag inp = (SOAC.inputArray inp `elem` interesting, inp)

-- | Given inputs flagged as interesting ('True') or not, repeatedly
-- split one transform off every interesting input using
-- 'inputToOutput'.  Each round requires all interesting inputs to yield
-- the same transform.  The shared transforms are collected in order,
-- and uninteresting inputs are left untouched.  Stops as soon as some
-- interesting input yields no transform, the transforms disagree, or
-- there are no interesting inputs.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps =
  case foldM inspect (Nothing, []) inps of
    Just (Just mot, inps') ->
      -- inspect accumulates in reverse; restore the original order.
      first (mot SOAC.<|) $ commonTransforms' $ reverse inps'
    _ -> (SOAC.noTransforms, map snd inps)
  where
    inspect (mot, prev) (True, inp) =
      case (mot, inputToOutput inp) of
        (Nothing, Just (ot, inp')) -> Just (Just ot, (True, inp') : prev)
        (Just ot1, Just (ot2, inp'))
          | ot1 == ot2 -> Just (Just ot2, (True, inp') : prev)
        _ -> Nothing
    inspect (mot, prev) inp = Just (mot, inp : prev)

-- | The largest 'rearrangeReach' that 'pullRearrange' and
-- 'pushRearrange' accept for this map nest.  It is one more than the
-- smaller of two values:
--
-- * the number of nesting levels;
-- * the minimum rank among the return types of the first nesting level
--   (or of the lambda, when there are no nesting levels).
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  min resDims (length levels) + 1
  where
    resDims = minDim $ case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    -- Smallest rank among the given types; 0 when there are none.
    minDim :: [Type] -> Int
    minDim [] = 0
    minDim ts = minimum $ map arrayRank ts

-- | Pull the leading 'SOAC.Rearrange' of the output transforms into a
-- map nest.  Each input of the nest is rearranged, with the permutation
-- extended to the input's rank, and the nest's return types are
-- rearranged to match.
--
-- Fails in three cases: the SOAC is not a map nest, the first output
-- transform is not a rearrange, or the permutation reaches deeper than
-- 'mapDepth'.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- pure $ SOAC.viewf ots
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot pull transpose"
  let -- Expand perm to cover the full extent of the input dimensionality.
      perm' inp =
        let r = SOAC.inputRank inp
         in take r perm ++ [length perm .. r - 1]
      addPerm inp = SOAC.addTransform (SOAC.Rearrange cs $ perm' inp) inp
      inputs' = map addPerm $ MapNest.inputs nest
  soac' <-
    MapNest.toSOAC $
      inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
  pure (soac', ots')

-- | Push a rearrangement of the inputs named in @inpIds@ out of a map
-- nest.  'fixupInputs' computes the permutation and the new inputs.
-- The nest is rebuilt with those inputs and with rearranged return
-- types, and the inverse permutation is prepended to the output
-- transforms.
--
-- Fails in three cases: the SOAC is not a map nest, 'fixupInputs'
-- gives up, or the permutation reaches deeper than 'mapDepth'.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot push transpose"
  let invertRearrange = SOAC.Rearrange mempty $ rearrangeInverse perm
  soac' <-
    MapNest.toSOAC $
      inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
  pure (soac', invertRearrange SOAC.<| ots)

-- | Actually also rearranges indices.
--
-- Rebuilds a map nest so that its nestings' return types reflect the
-- permutation @perm@.
--
-- * The nest's overall result types ('MapNest.typeOf') are rearranged
--   by @perm@.
-- * The first nesting gets those types with one outer dimension
--   stripped ('rowType'), the second nesting gets two stripped, and so
--   on.
-- * The width, body lambda and inputs are kept unchanged.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    nestings' =
      zipWith setReturnType nestings $
        drop 1 $
          iterate (map rowType) ts

    origts = MapNest.typeOf nest

    -- The permutation may be deeper than the rank of the type,
    -- but it is required that it is an identity permutation
    -- beyond that.  This is supposed to be checked as an
    -- invariant by whoever calls rearrangeReturnTypes.
    rearrangeType' t = rearrangeType (take (arrayRank t) perm) t

    ts = map rearrangeType' origts

    setReturnType nesting t' =
      nesting {MapNest.nestingReturnType = t'}

-- | Find a rearrangement to push out of a SOAC and apply it to every input.
--
-- Among the inputs whose array is in @inpIds@, take the first one whose
-- transform sequence ends in a 'SOAC.Rearrange' (as seen through
-- 'SOAC.viewl').  Its permutation is then added as a new rearrange
-- transform to every input.
--
-- Returns the permutation together with the new inputs.  Returns
-- 'Nothing' when:
--
-- * no such input exists; or
-- * some input has rank below the permutation's reach
--   ('rearrangeReach').
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps =
  case mapMaybe inputRearrange $ filter exposable inps of
    perm : _ -> do
      inps' <- mapM (fixupInput (rearrangeReach perm) perm) inps
      pure (perm, inps')
    [] -> Nothing
  where
    exposable = (`elem` inpIds) . SOAC.inputArray

    inputRearrange (SOAC.Input ts _ _)
      | _ SOAC.:> SOAC.Rearrange _ perm <- SOAC.viewl ts = Just perm
    inputRearrange _ = Nothing

    -- An input whose rank is below the reach @d@ cannot be rearranged.
    fixupInput d perm inp
      | r <- SOAC.inputRank inp,
        r >= d =
          Just $ SOAC.addTransform (SOAC.Rearrange mempty $ fitPerm r perm) inp
      | otherwise = Nothing

    -- Make the permutation cover exactly @r@ dimensions.  It is truncated
    -- when it is longer, and padded with identity dimensions when it is
    -- shorter (the same expansion pullRearrange performs).  Plain
    -- 'take r' would leave a permutation shorter than the rank when
    -- @r > length perm@, silently dropping dimensions.
    fitPerm r perm = take r perm ++ [length perm .. r - 1]

-- | Try to move a leading 'SOAC.Reshape' output transform into the SOAC.
--
-- This only applies to a map-only 'SOAC.Screma' whose lambda returns
-- primitive (scalar) types.
--
-- * Every input gets the reshape instead: a full 'SOAC.Reshape' for
--   one-dimensional inputs, a 'SOAC.ReshapeOuter' otherwise.
-- * The original map lambda runs in an innermost map whose width is the
--   last dimension of the new shape (0 if the shape is empty).
-- * That map is wrapped in one further map for every other dimension of
--   the new shape, from the inside out.
--
-- The remaining output transforms are returned.  Any other SOAC or
-- transform sequence fails with @"Cannot pull reshape"@.
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape (SOAC.Screma _ form inps) ots
  | Just maplam <- Futhark.isMapSOAC form,
    SOAC.Reshape cs k shape SOAC.:< ots' <- SOAC.viewf ots,
    all primType $ lambdaReturnType maplam = do
      let -- Width of the innermost map: the last dimension of the new shape.
          mapw' = case reverse $ shapeDims shape of
            [] -> intConst Int64 0
            d : _ -> d
          trInput inp
            | arrayRank (SOAC.inputType inp) == 1 =
                SOAC.addTransform (SOAC.Reshape cs k shape) inp
            | otherwise =
                SOAC.addTransform (SOAC.ReshapeOuter cs k shape) inp
          inputs' = map trInput inps
          inputTypes = map SOAC.inputType inputs'

      -- Wrap the SOAC built so far in another map of width @w@.
      -- @outershape@ holds the dimensions already handled by the inner
      -- maps.
      let outersoac ::
            ([SOAC.Input] -> SOAC) ->
            (SubExp, [SubExp]) ->
            TryFusion ([SOAC.Input] -> SOAC)
          outersoac inner (w, outershape) = do
            let addDims t = arrayOf t (Shape outershape) NoUniqueness
                retTypes = map addDims $ lambdaReturnType maplam

            ps <- forM inputTypes $ \inpt ->
              newParam "pullReshape_param" $
                stripArray (length shape - length outershape) inpt

            inner_body <-
              runBodyBuilder $
                eBody [SOAC.toExp $ inner $ map (SOAC.identInput . paramIdent) ps]
            let inner_fun =
                  Lambda
                    { lambdaParams = ps,
                      lambdaReturnType = retTypes,
                      lambdaBody = inner_body
                    }
            pure $ SOAC.Screma w $ Futhark.mapSOAC inner_fun

      -- Pair each outer dimension (innermost first, skipping the last
      -- dimension) with the dimensions inside it.  For a shape [a,b,c]
      -- this is [(b,[c]), (a,[b,c])].
      op' <-
        foldM outersoac (SOAC.Screma mapw' $ Futhark.mapSOAC maplam) $
          zip (drop 1 $ reverse $ shapeDims shape) $
            drop 1 . reverse . drop 1 . tails $
              shapeDims shape
      pure (op' inputs', ots')
pullReshape _ _ = fail "Cannot pull reshape"

-- Tie it all together in exposeInputs (for making inputs to a
-- consumer available) and pullOutputTransforms (for moving
-- output-transforms of a producer to its inputs instead).

-- | Make the inputs named by @inpIds@ of a fused SOAC available without
-- any array transforms.
--
-- Returns the adjusted SOAC together with the transforms that
-- 'commonTransforms' reports for those inputs.  Three strategies are
-- tried in order:
--
-- 1. Push an input rearrange out to the output transforms.
-- 2. Pull all output transforms into the SOAC.
-- 3. Expose the inputs as they are.
--
-- Fails with @"Cannot expose"@ when none of them leaves the named
-- inputs transform-free.
exposeInputs ::
  [VName] ->
  FusedSOAC ->
  TryFusion (FusedSOAC, SOAC.ArrayTransforms)
exposeInputs inpIds ker =
  (exposeInputs' =<< pushRearrange')
    <|> (exposeInputs' =<< pullRearrange')
    <|> exposeInputs' ker
  where
    ot = fsOutputTransform ker

    -- Move a rearrange on one of the wanted inputs onto the output.
    pushRearrange' = do
      (soac', ot') <- pushRearrange inpIds (fsSOAC ker) ot
      pure
        ker
          { fsSOAC = soac',
            fsOutputTransform = ot'
          }

    -- Absorb the output transforms into the SOAC.  This only counts as
    -- success if nothing is left over.
    pullRearrange' = do
      (soac', ot') <- pullRearrange (fsSOAC ker) ot
      unless (SOAC.nullTransforms ot') $
        fail "pullRearrange was not enough"
      pure
        ker
          { fsSOAC = soac',
            fsOutputTransform = SOAC.noTransforms
          }

    exposeInputs' ker' =
      case commonTransforms inpIds $ inputs ker' of
        (ot', inps')
          | all exposed inps' ->
              pure (ker' {fsSOAC = inps' `SOAC.setInputs` fsSOAC ker'}, ot')
        _ -> fail "Cannot expose"

    -- An input is acceptable if it has no transforms, or if it is not
    -- one of the inputs we are trying to expose.
    exposed (SOAC.Input ts _ _)
      | SOAC.nullTransforms ts = True
    exposed inp = SOAC.inputArray inp `notElem` inpIds

-- | Strategies for moving a SOAC's output transforms into the SOAC.
-- 'pullOutputTransforms' tries them in list order.
outputTransformPullers :: [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers = [pullRearrange, pullReshape]

-- | Move as many output transforms as possible into the SOAC.
--
-- Each puller from 'outputTransformPullers' is tried in turn.
--
-- * If a puller leaves no transforms, that result is returned.
-- * Otherwise, pulling the leftover transforms is retried recursively.
--   If that fails, the partial result is kept.
-- * If a puller fails, the next one is tried.
--
-- Fails with @"Cannot pull anything"@ only when every puller fails.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms = attempt outputTransformPullers
  where
    attempt [] _ _ = fail "Cannot pull anything"
    attempt (p : ps) soac ots = pullWith p <|> attempt ps soac ots
      where
        -- Apply one puller, then keep going on whatever it left behind.
        pullWith puller = do
          (soac', ots') <- puller soac ots
          if SOAC.nullTransforms ots'
            then pure (soac', SOAC.noTransforms)
            else pullOutputTransforms soac' ots' <|> pure (soac', ots')