{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.Fusion.LoopKernel
  ( FusedKer (..),
    newKernel,
    inputs,
    setInputs,
    arrInputs,
    transformOutput,
    attemptFusion,
    SOAC,
    MapNest,
  )
where

import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.HORep.MapNest as MapNest
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)

-- | The monad in which a single fusion attempt runs.
--
-- It can read the 'Scope' of the kernel being fused into, draw fresh
-- names, and fail (e.g. via 'fail' or 'empty').  Failure surfaces as
-- 'Nothing' through the underlying 'Maybe'.
newtype TryFusion a
  = TryFusion (ReaderT (Scope SOACS) (StateT VNameSource Maybe) a)
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Run a fusion attempt in the given scope, threading the caller's name
-- source through it.
--
-- On success the attempt's final name source is kept and its result is
-- wrapped in 'Just'.  On failure the original name source is put back
-- unchanged and 'Nothing' is returned.
tryFusion ::
  MonadFreshNames m =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion attempt) scope =
  modifyNameSource $ \namesrc ->
    fromMaybe (Nothing, namesrc) $
      first Just <$> runStateT (runReaderT attempt scope) namesrc

-- | Embed a 'Maybe' in 'TryFusion'.  'Just' succeeds with the wrapped
-- value; 'Nothing' makes the fusion attempt fail.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") return

-- | The SOAC representation from "Futhark.Analysis.HORep.SOAC",
-- specialised to the 'SOACS' lore.
type SOAC = SOAC.SOAC SOACS

-- | The map-nest representation from "Futhark.Analysis.HORep.MapNest",
-- specialised to the 'SOACS' lore.
type MapNest = MapNest.MapNest SOACS

-- | Bind each of @names@ to the matching identifier in the last argument,
-- after applying every array transform in @ts@.
--
-- Transforms are processed one at a time, in the order 'SOAC.viewf'
-- yields them.  For each transform:
--
-- * every current array is turned into a 'BasicOp' (see 'applyTransform');
-- * each result is bound, under the transform's certificates, to a fresh
--   identifier named after the corresponding output name.
--
-- Once no transforms remain, each output name is bound to its
-- corresponding identifier with a plain 'SubExp' binding.
--
-- XXX: This function is very gross.
transformOutput ::
  SOAC.ArrayTransforms ->
  [VName] ->
  [Ident] ->
  Binder SOACS ()
transformOutput ts names = go ts
  where
    go :: SOAC.ArrayTransforms -> [Ident] -> Binder SOACS ()
    go remaining current
      | t SOAC.:< rest <- SOAC.viewf remaining = do
          let (ops, certs) = unzip [applyTransform t ident | ident <- current]
          -- The result types of the ops give the types of the fresh names.
          resTypes <- fmap concat (mapM primOpType ops)
          fresh <- zipWithM (newIdent . baseString) names resTypes
          sequence_ $ zipWith3 bindOne certs fresh ops
          go rest fresh
      | otherwise =
          sequence_
            [ letBindNames [k] $ BasicOp $ SubExp $ Var $ identName ident
              | (k, ident) <- zip names current
            ]

    bindOne :: Certificates -> Ident -> BasicOp -> Binder SOACS ()
    bindOne cs (Ident nm tp) op =
      certifying cs $ letBind (Pattern [] [PatElem nm tp]) (BasicOp op)

-- | Express one array transform, applied to the array @v@, as a 'BasicOp'.
-- The result is paired with the certificates attached to the transform.
applyTransform :: SOAC.ArrayTransform -> Ident -> (BasicOp, Certificates)
applyTransform tr v =
  case tr of
    SOAC.Rearrange cs perm ->
      -- The permutation may cover only the leading dimensions; every
      -- remaining dimension keeps its position.
      (Rearrange (perm ++ [length perm .. rank - 1]) arr, cs)
    SOAC.Reshape cs shape ->
      (Reshape shape arr, cs)
    SOAC.ReshapeOuter cs shape ->
      (Reshape (reshapeOuter shape 1 vShape) arr, cs)
    SOAC.ReshapeInner cs shape ->
      (Reshape (reshapeInner shape 1 vShape) arr, cs)
    SOAC.Replicate cs n ->
      (Replicate n (Var arr), cs)
  where
    arr = identName v
    rank = arrayRank (identType v)
    vShape = arrayShape (identType v)

-- | Split off the first array transform of an input, as seen by
-- 'SOAC.viewf'.
--
-- Returns that transform together with the input stripped of it, or
-- 'Nothing' if the input carries no transforms.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts arr arrType)
  | t SOAC.:< rest <- SOAC.viewf ts = Just (t, SOAC.Input rest arr arrType)
  | otherwise = Nothing

-- | A kernel being built up by fusion: a SOAC plus the bookkeeping carried
-- alongside it.
data FusedKer = FusedKer
  { -- | The SOAC expression of the kernel, e.g. @map f x y@.
    fsoac :: SOAC,
    -- | Variables used in in-place updates in the kernel itself, as
    -- well as on the path to the kernel from the current position.
    -- This is used to avoid fusion that would violate in-place
    -- restrictions.
    inplace :: Names,
    -- | Variables accumulated by the fusions performed on this kernel.
    -- Note that this is a list of names, not a flag.  'newKernel' starts
    -- it out empty, so a non-empty list presumably means at least one
    -- fusion has been performed.
    fusedVars :: [VName],
    -- | The set of variables that were consumed by the SOACs
    -- contributing to this kernel.  Note that, by the type rules, the
    -- final SOAC may actually consume _more_ than its original
    -- contributors, which implies the need for 'Copy' expressions.
    fusedConsumed :: Names,
    -- | The names in scope at the kernel.
    kernelScope :: Scope SOACS,
    -- | Array transforms recorded for the kernel's output.  'newKernel'
    -- starts with 'SOAC.noTransforms'.
    outputTransform :: SOAC.ArrayTransforms,
    -- | The kernel's output names, as passed to 'newKernel'.
    outNames :: [VName],
    -- | The statement auxiliary data passed to 'newKernel'.
    kerAux :: StmAux ()
  }
  deriving (Show)

-- | Create a kernel holding a single SOAC into which nothing has been
-- fused yet.
--
-- The consumed names seed both 'inplace' and 'fusedConsumed'.  No output
-- transforms are pending and 'fusedVars' is empty.
newKernel :: StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel aux soac consumed out_nms scope =
  FusedKer
    { fsoac = soac,
      kerAux = aux,
      kernelScope = scope,
      outNames = out_nms,
      inplace = consumed,
      fusedConsumed = consumed,
      fusedVars = [],
      outputTransform = SOAC.noTransforms
    }

-- | The set of array variables underlying the inputs of the kernel's SOAC.
arrInputs :: FusedKer -> S.Set VName
arrInputs ker = S.fromList [SOAC.inputArray inp | inp <- inputs ker]

-- | The inputs of the kernel's SOAC.
inputs :: FusedKer -> [SOAC.Input]
inputs ker = SOAC.inputs (fsoac ker)

-- | Replace the inputs of the kernel's SOAC, leaving every other field of
-- the kernel untouched.
setInputs :: [SOAC.Input] -> FusedKer -> FusedKer
setInputs newInps ker = ker {fsoac = SOAC.setInputs newInps (fsoac ker)}

-- | Run 'optimizeSOAC' on the producer @soac@, then retry
-- 'applyFusionRules' with the result.
--
-- Before retrying:
--
-- * the array transforms it returns are added (via
--   'SOAC.addInitialTransforms') to every kernel input that reads one of
--   the producer's outputs (@outVars@);
-- * those inputs' types are refreshed from the optimised producer.
tryOptimizeSOAC ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeSOAC unfus_nms outVars soac consumed ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let producerOuts = zipWith Ident outVars (SOAC.typeOf soac')
      withTransforms = setInputs (map (prependIfOutput ots) (inputs ker)) ker
  applyFusionRules unfus_nms outVars soac' consumed $
    fixInputTypes producerOuts withTransforms
  where
    -- Only inputs that read a producer output get the extra transforms.
    prependIfOutput ots inp
      | SOAC.inputArray inp `notElem` outVars = inp
      | otherwise = SOAC.addInitialTransforms ots inp

-- | Run 'optimizeKernel' on the consumer kernel, passing it the producer's
-- output names, then retry 'applyFusionRules' with the resulting kernel.
tryOptimizeKernel ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeKernel unfus_nms outVars soac consumed ker =
  optimizeKernel (Just outVars) ker
    >>= applyFusionRules unfus_nms outVars soac consumed

-- | Call 'exposeInputs' on the kernel for the producer's outputs.
--
-- * If it leaves no array transforms, fuse directly with 'fuseSOACwithKer'.
-- * Otherwise, move those transforms onto the producer with
--   'pullOutputTransforms'.  If that leaves no transforms, refresh the
--   kernel's input types and retry 'applyFusionRules'; if transforms
--   remain, the attempt fails.
tryExposeInputs ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryExposeInputs unfus_nms outVars soac consumed ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer unfus_nms outVars soac consumed ker'
    else pullAndRetry ker' ots
  where
    pullAndRetry exposed pending = do
      (soac', leftover) <- pullOutputTransforms soac pending
      unless (SOAC.nullTransforms leftover) $
        fail "tryExposeInputs could not pull SOAC transforms"
      let producerOuts = zipWith Ident outVars $ SOAC.typeOf soac'
      applyFusionRules unfus_nms outVars soac' consumed $
        fixInputTypes producerOuts exposed

-- | Update the types of the kernel's SOAC inputs from @outIdents@.
--
-- Every input whose array is named by one of @outIdents@ takes that
-- identifier's type (the first match wins).  All other inputs are left
-- unchanged.
fixInputTypes :: [Ident] -> FusedKer -> FusedKer
fixInputTypes outIdents ker =
  setInputs (map retype (inputs ker)) ker
  where
    retype inp@(SOAC.Input ts arr _) =
      case find ((== arr) . identName) outIdents of
        Just match -> SOAC.Input ts arr (identType match)
        Nothing -> inp

-- | Try each fusion strategy in turn and return the first that succeeds.
--
-- The order is 'tryOptimizeSOAC', 'tryOptimizeKernel', 'fuseSOACwithKer',
-- then 'tryExposeInputs'.  The combined attempt fails only if all of
-- them fail.
applyFusionRules ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
applyFusionRules unfus_nms outVars soac consumed ker =
  attempt tryOptimizeSOAC
    <|> attempt tryOptimizeKernel
    <|> attempt fuseSOACwithKer
    <|> attempt tryExposeInputs
  where
    attempt rule = rule unfus_nms outVars soac consumed ker

-- | Attempt to fuse the producer @soac@ into the kernel @ker@, running the
-- fusion rules in the kernel's scope.
--
-- Returns 'Nothing' if no rule applies.  A successful result is passed
-- through 'removeUnusedParamsFromKer'.
attemptFusion ::
  MonadFreshNames m =>
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  m (Maybe FusedKer)
attemptFusion unfus_nms outVars soac consumed ker = do
  fused <-
    tryFusion
      (applyFusionRules unfus_nms outVars soac consumed ker)
      (kernelScope ker)
  return (fmap removeUnusedParamsFromKer fused)

-- | For a kernel whose SOAC is a 'SOAC.Screma', prune lambda parameters
-- (and their inputs) using 'removeUnusedParams'.  Kernels with any other
-- kind of SOAC are returned unchanged.
removeUnusedParamsFromKer :: FusedKer -> FusedKer
removeUnusedParamsFromKer ker
  | SOAC.Screma {} <- soac =
      ker {fsoac = SOAC.setLambda lam' $ SOAC.setInputs inps' soac}
  | otherwise = ker
  where
    soac = fsoac ker
    (lam', inps') = removeUnusedParams (SOAC.lambda soac) (SOAC.inputs soac)

-- | Drop the lambda parameters that are not free in the lambda body,
-- together with the inputs at the same positions.
--
-- If no parameter is used at all, the first parameter/input pair is kept,
-- so the result is non-empty whenever the paired input is.  Parameters and
-- inputs are paired with 'zip', so any excess on either side is dropped.
removeUnusedParams :: Lambda -> [SOAC.Input] -> (Lambda, [SOAC.Input])
removeUnusedParams l inps = (l {lambdaParams = ps'}, inps')
  where
    paired = zip (lambdaParams l) inps
    bodyFree = freeIn $ lambdaBody l
    kept = filter ((`nameIn` bodyFree) . paramName . fst) paired
    (ps', inps') = unzip $ case (kept, paired) of
      ([], firstPair : _) -> [firstPair]
      _ -> kept

-- | The array names of the kernel's inputs for which 'SOAC.isVarishInput'
-- returns a name.  Shared by 'mapFusionOK' and 'mapWriteFusionOK', which
-- previously duplicated this computation.
varishKernelInputs :: FusedKer -> [VName]
varishKernelInputs = mapMaybe SOAC.isVarishInput . inputs

-- | Check that the consumer uses at least one output of the producer
-- unmodified.
mapFusionOK :: [VName] -> FusedKer -> Bool
mapFusionOK outVars ker = any (`elem` inpIds) outVars
  where
    inpIds = varishKernelInputs ker

-- | Check that the consumer uses all the outputs of the producer
-- unmodified.  Note that this holds vacuously when @outVars@ is empty.
mapWriteFusionOK :: [VName] -> FusedKer -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = varishKernelInputs ker

-- | The brain of this module: Fusing a SOAC with a Kernel.
fuseSOACwithKer ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
fuseSOACwithKer :: Names -> [VName] -> SOAC -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set [VName]
outVars SOAC
soac_p Names
soac_p_consumed FusedKer
ker = do
  -- We are fusing soac_p into soac_c, i.e., the output of soac_p is going
  -- into soac_c.
  let soac_c :: SOAC
soac_c = FusedKer -> SOAC
fsoac FusedKer
ker
      inp_p_arr :: [Input]
inp_p_arr = SOAC -> [Input]
forall lore. SOAC lore -> [Input]
SOAC.inputs SOAC
soac_p
      horizFuse :: Bool
horizFuse =
        Names
unfus_set Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
/= Names
forall a. Monoid a => a
mempty
          Bool -> Bool -> Bool
&& SOAC -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC
soac_c
      inp_c_arr :: [Input]
inp_c_arr = SOAC -> [Input]
forall lore. SOAC lore -> [Input]
SOAC.inputs SOAC
soac_c
      lam_p :: Lambda SOACS
lam_p = SOAC -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC
soac_p
      lam_c :: Lambda SOACS
lam_c = SOAC -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC
soac_c
      w :: SubExp
w = SOAC -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC
soac_p
      returned_outvars :: [VName]
returned_outvars = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
      success :: [VName] -> SOAC -> TryFusion FusedKer
success [VName]
res_outnms SOAC
res_soac = do
        let fusedVars_new :: [VName]
fusedVars_new = FusedKer -> [VName]
fusedVars FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars
        -- Avoid name duplication, because the producer lambda is not
        -- removed from the program until much later.
        Lambda SOACS
uniq_lam <- Lambda SOACS -> TryFusion (Lambda SOACS)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda SOACS -> TryFusion (Lambda SOACS))
-> Lambda SOACS -> TryFusion (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ SOAC -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC
res_soac
        FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return (FusedKer -> TryFusion FusedKer) -> FusedKer -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
          FusedKer
ker
            { fsoac :: SOAC
fsoac = Lambda SOACS
uniq_lam Lambda SOACS -> SOAC -> SOAC
forall lore. Lambda lore -> SOAC lore -> SOAC lore
`SOAC.setLambda` SOAC
res_soac,
              fusedVars :: [VName]
fusedVars = [VName]
fusedVars_new,
              inplace :: Names
inplace = FusedKer -> Names
inplace FusedKer
ker Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
soac_p_consumed,
              fusedConsumed :: Names
fusedConsumed = FusedKer -> Names
fusedConsumed FusedKer
ker Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
soac_p_consumed,
              outNames :: [VName]
outNames = [VName]
res_outnms
            }

  [(VName, Ident)]
outPairs <- [(VName, Type)]
-> ((VName, Type) -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [Type] -> [(VName, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([Type] -> [(VName, Type)]) -> [Type] -> [(VName, Type)]
forall a b. (a -> b) -> a -> b
$ (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ SOAC -> [Type]
forall lore. SOAC lore -> [Type]
SOAC.typeOf SOAC
soac_p) (((VName, Type) -> TryFusion (VName, Ident))
 -> TryFusion [(VName, Ident)])
-> ((VName, Type) -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall a b. (a -> b) -> a -> b
$ \(VName
outVar, Type
t) -> do
    VName
outVar' <- String -> TryFusion VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> TryFusion VName) -> String -> TryFusion VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
outVar String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_elem"
    (VName, Ident) -> TryFusion (VName, Ident)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
outVar, VName -> Type -> Ident
Ident VName
outVar' Type
t)

  let mapLikeFusionCheck :: ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck =
        let (Lambda SOACS
res_lam, [Input]
new_inp) = Names
-> Lambda SOACS
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [Input]
-> (Lambda SOACS, [Input])
forall lore.
Bindable lore =>
Names
-> Lambda lore
-> [Input]
-> [(VName, Ident)]
-> Lambda lore
-> [Input]
-> (Lambda lore, [Input])
fuseMaps Names
unfus_set Lambda SOACS
lam_p [Input]
inp_p_arr [(VName, Ident)]
outPairs Lambda SOACS
lam_c [Input]
inp_c_arr
            ([VName]
extra_nms, [Type]
extra_rtps) =
              [(VName, Type)] -> ([VName], [Type])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Type)] -> ([VName], [Type]))
-> [(VName, Type)] -> ([VName], [Type])
forall a b. (a -> b) -> a -> b
$
                ((VName, Type) -> Bool) -> [(VName, Type)] -> [(VName, Type)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
unfus_set) (VName -> Bool)
-> ((VName, Type) -> VName) -> (VName, Type) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Type) -> VName
forall a b. (a, b) -> a
fst) ([(VName, Type)] -> [(VName, Type)])
-> [(VName, Type)] -> [(VName, Type)]
forall a b. (a -> b) -> a -> b
$
                  [VName] -> [Type] -> [(VName, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([Type] -> [(VName, Type)]) -> [Type] -> [(VName, Type)]
forall a b. (a -> b) -> a -> b
$ (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Type -> Type
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray Int
1) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ SOAC -> [Type]
forall lore. SOAC lore -> [Type]
SOAC.typeOf SOAC
soac_p
            res_lam' :: Lambda SOACS
res_lam' = Lambda SOACS
res_lam {lambdaReturnType :: [Type]
lambdaReturnType = Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
res_lam [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
extra_rtps}
         in ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp)

  Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Bool
horizFuse Bool -> Bool -> Bool
&& Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedKer -> ArrayTransforms
outputTransform FusedKer
ker)) (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$
    String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Horizontal fusion is invalid in the presence of output transforms."

  case (SOAC
soac_c, SOAC
soac_p) of
    (SOAC, SOAC)
_ | SOAC -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC
soac_c -> String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC widths must match."
    ( SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_c [Reduce SOACS]
reds_c Lambda SOACS
_) [Input]
_,
      SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_p [Reduce SOACS]
reds_p Lambda SOACS
_) [Input]
_
      )
        | [VName] -> FusedKer -> Bool
mapFusionOK (Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Scan SOACS] -> Int
forall lore. [Scan lore] -> Int
Futhark.scanResults [Scan SOACS]
scans_p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce SOACS] -> Int
forall lore. [Reduce lore] -> Int
Futhark.redResults [Reduce SOACS]
reds_p) [VName]
outVars) FusedKer
ker
            Bool -> Bool -> Bool
|| Bool
horizFuse -> do
          let red_nes_p :: [SubExp]
red_nes_p = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall lore. Reduce lore -> [SubExp]
redNeutral [Reduce SOACS]
reds_p
              red_nes_c :: [SubExp]
red_nes_c = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall lore. Reduce lore -> [SubExp]
redNeutral [Reduce SOACS]
reds_c
              scan_nes_p :: [SubExp]
scan_nes_p = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall lore. Scan lore -> [SubExp]
scanNeutral [Scan SOACS]
scans_p
              scan_nes_c :: [SubExp]
scan_nes_c = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall lore. Scan lore -> [SubExp]
scanNeutral [Scan SOACS]
scans_c
              (Lambda SOACS
res_lam', [Input]
new_inp) =
                Names
-> [VName]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda SOACS, [Input])
forall lore.
Bindable lore =>
Names
-> [VName]
-> Lambda lore
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda lore
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda lore, [Input])
fuseRedomap
                  Names
unfus_set
                  [VName]
outVars
                  Lambda SOACS
lam_p
                  [SubExp]
scan_nes_p
                  [SubExp]
red_nes_p
                  [Input]
inp_p_arr
                  [(VName, Ident)]
outPairs
                  Lambda SOACS
lam_c
                  [SubExp]
scan_nes_c
                  [SubExp]
red_nes_c
                  [Input]
inp_c_arr
              ([VName]
soac_p_scanout, [VName]
soac_p_redout, [VName]
_soac_p_mapout) =
                Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_p) ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_p) [VName]
outVars
              ([VName]
soac_c_scanout, [VName]
soac_c_redout, [VName]
soac_c_mapout) =
                Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_c) ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_c) ([VName] -> ([VName], [VName], [VName]))
-> [VName] -> ([VName], [VName], [VName])
forall a b. (a -> b) -> a -> b
$ FusedKer -> [VName]
outNames FusedKer
ker
              unfus_arrs :: [VName]
unfus_arrs = [VName]
returned_outvars [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
\\ ([VName]
soac_p_scanout [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout)
          [VName] -> SOAC -> TryFusion FusedKer
success
            ( [VName]
soac_p_scanout [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_scanout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_redout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_mapout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs
            )
            (SOAC -> TryFusion FusedKer) -> SOAC -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma
              SubExp
w
              ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall lore.
[Scan lore] -> [Reduce lore] -> Lambda lore -> ScremaForm lore
ScremaForm ([Scan SOACS]
scans_p [Scan SOACS] -> [Scan SOACS] -> [Scan SOACS]
forall a. [a] -> [a] -> [a]
++ [Scan SOACS]
scans_c) ([Reduce SOACS]
reds_p [Reduce SOACS] -> [Reduce SOACS] -> [Reduce SOACS]
forall a. [a] -> [a] -> [a]
++ [Reduce SOACS]
reds_c) Lambda SOACS
res_lam')
              [Input]
new_inp

    ------------------
    -- Scatter fusion --
    ------------------

    -- Map-Scatter fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Scatter is not writing to any array used in
    -- the Map.
    ( SOAC.Scatter SubExp
_len Lambda SOACS
_lam [Input]
_ivs [(Shape, Int, VName)]
dests,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_
      )
        | Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the scatter, i.e., not used elsewhere.
          Bool -> Bool
not ((VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars),
          -- 2. all arrays produced by the map are input to the scatter.
          [VName] -> FusedKer -> Bool
mapWriteFusionOK [VName]
outVars FusedKer
ker -> do
          let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
          [VName] -> SOAC -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC -> TryFusion FusedKer) -> SOAC -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp -> Lambda SOACS -> [Input] -> [(Shape, Int, VName)] -> SOAC
forall lore.
SubExp
-> Lambda lore -> [Input] -> [(Shape, Int, VName)] -> SOAC lore
SOAC.Scatter SubExp
w Lambda SOACS
res_lam' [Input]
new_inp [(Shape, Int, VName)]
dests

    -- Map-Hist fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Hist is not writing to any array used in
    -- the Map.
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops Lambda SOACS
_ [Input]
_,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_
      )
        | Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the hist, i.e., not used elsewhere.
          Bool -> Bool
not ((VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars),
          -- 2. all arrays produced by the map are input to the hist.
          [VName] -> FusedKer -> Bool
mapWriteFusionOK [VName]
outVars FusedKer
ker -> do
          let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
          [VName] -> SOAC -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC -> TryFusion FusedKer) -> SOAC -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC
forall lore.
SubExp -> [HistOp lore] -> Lambda lore -> [Input] -> SOAC lore
SOAC.Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
res_lam' [Input]
new_inp

    -- Hist-Hist fusion
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops_c Lambda SOACS
_ [Input]
_,
      SOAC.Hist SubExp
_ [HistOp SOACS]
ops_p Lambda SOACS
_ [Input]
_
      )
        | Bool
horizFuse -> do
          let p_num_buckets :: Int
p_num_buckets = [HistOp SOACS] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_p
              c_num_buckets :: Int
c_num_buckets = [HistOp SOACS] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_c
              (BodyT SOACS
body_p, BodyT SOACS
body_c) = (Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_c)
              body' :: BodyT SOACS
body' =
                Body :: forall lore. BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
Body
                  { bodyDec :: BodyDec SOACS
bodyDec = BodyT SOACS -> BodyDec SOACS
forall lore. BodyT lore -> BodyDec lore
bodyDec BodyT SOACS
body_p, -- body_p and body_c have the same lores
                    bodyStms :: Stms SOACS
bodyStms = BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_c,
                    bodyResult :: [SubExp]
bodyResult =
                      Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
c_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_c)
                        [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
p_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_p)
                        [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_c)
                        [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_p)
                  }
              lam' :: Lambda SOACS
lam' =
                Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
                  { lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_c [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_p,
                    lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
body',
                    lambdaReturnType :: [Type]
lambdaReturnType =
                      Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate (Int
c_num_buckets Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
p_num_buckets) (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
                        [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
lam_c)
                        [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
lam_p)
                  }
          [VName] -> SOAC -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC -> TryFusion FusedKer) -> SOAC -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC
forall lore.
SubExp -> [HistOp lore] -> Lambda lore -> [Input] -> SOAC lore
SOAC.Hist SubExp
w ([HistOp SOACS]
ops_c [HistOp SOACS] -> [HistOp SOACS] -> [HistOp SOACS]
forall a. Semigroup a => a -> a -> a
<> [HistOp SOACS]
ops_p) Lambda SOACS
lam' ([Input]
inp_c_arr [Input] -> [Input] -> [Input]
forall a. Semigroup a => a -> a -> a
<> [Input]
inp_p_arr)

    -- Scatter-write fusion.
    ( SOAC.Scatter SubExp
_len_c Lambda SOACS
_lam_c [Input]
ivs_c [(Shape, Int, VName)]
as_c,
      SOAC.Scatter SubExp
_len_p Lambda SOACS
_lam_p [Input]
ivs_p [(Shape, Int, VName)]
as_p
      )
        | Bool
horizFuse -> do
          let zipW :: [(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, array)]
as_xs [a]
xs [(Shape, Int, array)]
as_ys [a]
ys = [a]
xs_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
xs_vals [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_vals
                where
                  ([a]
xs_indices, [a]
xs_vals) = [(Shape, Int, array)] -> [a] -> ([a], [a])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_xs [a]
xs
                  ([a]
ys_indices, [a]
ys_vals) = [(Shape, Int, array)] -> [a] -> ([a], [a])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, array)]
as_ys [a]
ys
          let (BodyT SOACS
body_p, BodyT SOACS
body_c) = (Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_c)
          let body' :: BodyT SOACS
body' =
                Body :: forall lore. BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
Body
                  { bodyDec :: BodyDec SOACS
bodyDec = BodyT SOACS -> BodyDec SOACS
forall lore. BodyT lore -> BodyDec lore
bodyDec BodyT SOACS
body_p, -- body_p and body_c have the same lores
                    bodyStms :: Stms SOACS
bodyStms = BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_c,
                    bodyResult :: [SubExp]
bodyResult = [(Shape, Int, VName)]
-> [SubExp] -> [(Shape, Int, VName)] -> [SubExp] -> [SubExp]
forall array a array.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_c) [(Shape, Int, VName)]
as_p (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_p)
                  }
          let lam' :: Lambda SOACS
lam' =
                Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
                  { lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_c [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_p,
                    lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
body',
                    lambdaReturnType :: [Type]
lambdaReturnType = [(Shape, Int, VName)]
-> [Type] -> [(Shape, Int, VName)] -> [Type] -> [Type]
forall array a array.
[(Shape, Int, array)] -> [a] -> [(Shape, Int, array)] -> [a] -> [a]
zipW [(Shape, Int, VName)]
as_c (Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
lam_c) [(Shape, Int, VName)]
as_p (Lambda SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda SOACS
lam_p)
                  }
          [VName] -> SOAC -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC -> TryFusion FusedKer) -> SOAC -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp -> Lambda SOACS -> [Input] -> [(Shape, Int, VName)] -> SOAC
forall lore.
SubExp
-> Lambda lore -> [Input] -> [(Shape, Int, VName)] -> SOAC lore
SOAC.Scatter SubExp
w Lambda SOACS
lam' ([Input]
ivs_c [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ [Input]
ivs_p) ([(Shape, Int, VName)]
as_c [(Shape, Int, VName)]
-> [(Shape, Int, VName)] -> [(Shape, Int, VName)]
forall a. [a] -> [a] -> [a]
++ [(Shape, Int, VName)]
as_p)
    (SOAC.Scatter {}, SOAC
_) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a write with anything else than a write or a map"
    (SOAC
_, SOAC.Scatter {}) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a write with anything else than a write or a map"
    ----------------------------
    -- Stream-Stream Fusions: --
    ----------------------------
    (SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
nes [Input]
_)
      | [VName] -> FusedKer -> Bool
mapFusionOK (Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [VName]
outVars) FusedKer
ker Bool -> Bool -> Bool
|| Bool
horizFuse -> do
        -- fuse two SEQUENTIAL streams
        ([VName]
res_nms, SOAC
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC
-> SOAC
-> TryFusion ([VName], SOAC)
fuseStreamHelper (FusedKer -> [VName]
outNames FusedKer
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC
soac_c SOAC
soac_p
        [VName] -> SOAC -> TryFusion FusedKer
success [VName]
res_nms SOAC
res_stream
    (SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Fusion conditions not met for two SEQ streams!"
    (SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC.Stream {}) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream SubExp
_ StreamForm SOACS
_ Lambda SOACS
_ [SubExp]
nes [Input]
_)
      | [VName] -> FusedKer -> Bool
mapFusionOK (Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [VName]
outVars) FusedKer
ker Bool -> Bool -> Bool
|| Bool
horizFuse -> do
        -- fuse two PARALLEL streams
        ([VName]
res_nms, SOAC
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC
-> SOAC
-> TryFusion ([VName], SOAC)
fuseStreamHelper (FusedKer -> [VName]
outNames FusedKer
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC
soac_c SOAC
soac_p
        [VName] -> SOAC -> TryFusion FusedKer
success [VName]
res_nms SOAC
res_stream
    (SOAC.Stream {}, SOAC.Stream {}) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Fusion conditions not met for two PAR streams!"
    -------------------------------------------------------------------
    --- If one is a stream, translate the other to a stream as well.---
    --- This does not get in trouble (infinite computation) because ---
    ---   scan's translation to Stream introduces a hindrance to    ---
    ---   (horizontal fusion), hence repeated application is for the---
    ---   moment impossible. However, if with a dependence-graph rep---
    ---   we could run in an infinite recursion, i.e., repeatedly   ---
    ---   fusing map o scan into an infinity of Stream levels!      ---
    -------------------------------------------------------------------
    (SOAC.Stream SubExp
_ StreamForm SOACS
form2 Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC
_) -> do
      -- If this rule is matched then soac_p is NOT a stream.
      -- To fuse a stream kernel, we transform soac_p to a stream, which
      -- borrows the sequential/parallel property of the soac_c Stream,
      -- and recursively perform stream-stream fusion.
      (SOAC
soac_p', [Ident]
newacc_ids) <- SOAC -> TryFusion (SOAC, [Ident])
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, Op lore ~ SOAC lore) =>
SOAC lore -> m (SOAC lore, [Ident])
SOAC.soacToStream SOAC
soac_p
      SOAC
soac_p'' <- case StreamForm SOACS
form2 of
        Sequential {} -> SOAC -> TryFusion SOAC
toSeqStream SOAC
soac_p'
        StreamForm SOACS
_ -> SOAC -> TryFusion SOAC
forall (m :: * -> *) a. Monad m => a -> m a
return SOAC
soac_p'
      if SOAC
soac_p' SOAC -> SOAC -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC
soac_p
        then String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC could not be turned into stream."
        else Names -> [VName] -> SOAC -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars) SOAC
soac_p'' Names
soac_p_consumed FusedKer
ker
    (SOAC
_, SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_) | Just [Scan SOACS]
_ <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall lore. ScremaForm lore -> Maybe [Scan lore]
Futhark.isScanSOAC ScremaForm SOACS
form -> do
      -- A Scan soac can be currently only fused as a (sequential) stream,
      -- hence it is first translated to a (sequential) Stream and then
      -- fusion with a kernel is attempted.
      (SOAC
soac_p', [Ident]
newacc_ids) <- SOAC -> TryFusion (SOAC, [Ident])
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, Op lore ~ SOAC lore) =>
SOAC lore -> m (SOAC lore, [Ident])
SOAC.soacToStream SOAC
soac_p
      if SOAC
soac_p' SOAC -> SOAC -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC
soac_p
        then Names -> [VName] -> SOAC -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars) SOAC
soac_p' Names
soac_p_consumed FusedKer
ker
        else String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC could not be turned into stream."
    (SOAC
_, SOAC.Stream SubExp
_ StreamForm SOACS
form_p Lambda SOACS
_ [SubExp]
_ [Input]
_) -> do
      -- If it reached this case then soac_c is NOT a Stream kernel,
      -- hence transform the kernel's soac to a stream and attempt
      -- stream-stream fusion recursivelly.
      -- The newly created stream corresponding to soac_c borrows the
      -- sequential/parallel property of the soac_p stream.
      (SOAC
soac_c', [Ident]
newacc_ids) <- SOAC -> TryFusion (SOAC, [Ident])
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, Op lore ~ SOAC lore) =>
SOAC lore -> m (SOAC lore, [Ident])
SOAC.soacToStream SOAC
soac_c
      Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (SOAC
soac_c' SOAC -> SOAC -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC
soac_c) (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$ String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC could not be turned into stream."
      SOAC
soac_c'' <- case StreamForm SOACS
form_p of
        StreamForm SOACS
Sequential -> SOAC -> TryFusion SOAC
toSeqStream SOAC
soac_c'
        StreamForm SOACS
_ -> SOAC -> TryFusion SOAC
forall (m :: * -> *) a. Monad m => a -> m a
return SOAC
soac_c'

      Names -> [VName] -> SOAC -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set [VName]
outVars SOAC
soac_p Names
soac_p_consumed (FusedKer -> TryFusion FusedKer) -> FusedKer -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
        FusedKer
ker {fsoac :: SOAC
fsoac = SOAC
soac_c'', outNames :: [VName]
outNames = (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ FusedKer -> [VName]
outNames FusedKer
ker}

    ---------------------------------
    --- DEFAULT, CANNOT FUSE CASE ---
    ---------------------------------
    (SOAC, SOAC)
_ -> String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse"

-- | The element order of a stream form.  Sequential streams are always
-- 'InOrder'; a parallel stream reports the order stored in its form.
getStreamOrder :: StreamForm lore -> StreamOrd
getStreamOrder form =
  case form of
    Sequential -> InOrder
    Parallel order _ _ -> order

-- | Fuse two streams into a single stream.
--
-- The fifth argument supplies the width of the result, and its lambda is
-- passed /second/ to 'fuseRedomap'.  The sixth argument's lambda is passed
-- /first/.  These are presumably the consumer and the producer,
-- respectively (confirm against 'fuseRedomap').
--
-- Fails when:
--
-- * the two streams have different element orders;
-- * either stream lambda lacks its leading chunk parameter;
-- * the forms cannot be merged (one sequential, one parallel);
-- * either SOAC is not a stream.
--
-- On success, returns the output names of the fused stream together with
-- the fused stream.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 form2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ form1 lam1 nes1 inp1_arr)
    | getStreamOrder form2 /= getStreamOrder form1 =
      fail "fusion conditions not met!"
    | otherwise = do
      -- Very similar to redomap o redomap composition, except that the
      -- leading chunk parameter of each stream lambda must first be
      -- removed and then put back on the resulting stream lambda.  The
      -- second lambda's chunk parameter is renamed to the first one's, so
      -- both lambdas refer to the same chunk name.
      chunk1 <- chunkParam lam1
      chunk2 <- chunkParam lam2
      let lam20 =
            substituteNames
              (M.singleton (paramName chunk2) (paramName chunk1))
              lam2
          lam1' = dropChunkParam lam1
          lam2' = dropChunkParam lam20
          (res_lam', new_inp) =
            fuseRedomap
              unfus_set
              outVars
              lam1'
              []
              nes1
              inp1_arr
              outPairs
              lam2'
              []
              nes2
              inp2_arr
          res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
          unfus_accs = take (length nes1) outVars
          unfus_arrs = filter (`nameIn` unfus_set) outVars
      res_form <- mergeForms form2 form1
      return
        ( unfus_accs ++ out_kernms ++ unfus_arrs,
          SOAC.Stream w2 res_form res_lam'' (nes1 ++ nes2) new_inp
        )
    where
      -- The chunk parameter is the first parameter of a stream lambda.
      -- Matching it totally (instead of with 'head') turns a malformed
      -- lambda into an ordinary fusion failure rather than a crash.
      chunkParam lam =
        case lambdaParams lam of
          p : _ -> return p
          [] -> fail "Stream lambda has no chunk parameter."
      -- 'drop 1' is the total counterpart of 'tail'.
      dropChunkParam lam = lam {lambdaParams = drop 1 $ lambdaParams lam}
      mergeForms Sequential Sequential = return Sequential
      mergeForms (Parallel _ comm2 lam2r) (Parallel o1 comm1 lam1r) =
        return $ Parallel o1 (comm1 <> comm2) (mergeReduceOps lam1r lam2r)
      mergeForms _ _ = fail "Fusing sequential to parallel stream disallowed!"
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"

-- | Convert a stream into a sequential stream.  A stream that is already
-- sequential is returned as-is.  A parallel stream is rebuilt with the
-- same width, lambda, accumulators and inputs but a 'Sequential' form.
-- Any other SOAC makes this fail.
toSeqStream :: SOAC -> TryFusion SOAC
toSeqStream soac =
  case soac of
    SOAC.Stream w form lam nes inps
      | Sequential <- form -> return soac
      | Parallel {} <- form -> return $ SOAC.Stream w Sequential lam nes inps
    _ -> fail "toSeqStream expects a stream, but given a SOAC."

-- The following optimisations and transformations expose opportunities for fusion.

-- | Run 'optimizeSOAC' on the kernel's SOAC, starting from the kernel's
-- current output transforms.  The resulting SOAC and transforms are
-- stored back into the kernel.  Fails whenever 'optimizeSOAC' fails.
optimizeKernel :: Maybe [VName] -> FusedKer -> TryFusion FusedKer
optimizeKernel inp ker =
  updateKernel <$> optimizeSOAC inp (fsoac ker) (outputTransform ker)
  where
    updateKernel (newSoac, newTrans) =
      ker {fsoac = newSoac, outputTransform = newTrans}

-- | Try each of 'optimizations' in order, threading the SOAC and its
-- output transforms from one optimisation into the next.  An
-- optimisation that fails is skipped (via '<|>') and leaves the state
-- untouched.  Fails with "No optimisation applied" unless at least one
-- optimisation succeeded.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> return (soac', os')
  where
    comb (changed, soac', os') f =
      ( do
          -- Pass the transforms accumulated so far (os'), not the initial
          -- ones (os).  Otherwise a later optimisation would see the
          -- previous step's SOAC paired with stale transforms.
          (soac'', os'') <- f inp soac' os'
          return (True, soac'', os'')
      )
        <|> return (changed, soac', os')

-- | A rewrite of a SOAC together with the array transforms applied to its
-- output.  It fails in 'TryFusion' when it does not apply.  The leading
-- @Maybe [VName]@ is forwarded by 'optimizeSOAC' ('iswim' ignores it).
type Optimization =
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)

-- | The optimisations tried, in list order, by 'optimizeSOAC'.  At
-- present this is just 'iswim'.
optimizations :: [Optimization]
optimizations = pure iswim

-- | Interchange a scan whose operator is a map into a map whose body is
-- a scan.
--
-- Applies only when all of the following hold:
--
-- * the SOAC is a single-operator scan;
-- * 'rwimPossible' accepts the scan operator;
-- * every neutral element is a variable.
--
-- The inputs have their two outermost dimensions transposed, and a
-- 'SOAC.Rearrange' (carrying the inner map's certificates) is appended to
-- the output transforms.  Otherwise this fails with "ISWIM does not apply."
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w form arrs) ots
  | Just [Futhark.Scan scan_fun nes] <- Futhark.isScanSOAC form,
    Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun,
    Just nes_names <- mapM subExpVar nes = do
    let -- The neutral elements become extra inputs of the new outer map.
        nes_inputs =
          map SOAC.identInput $
            zipWith Ident nes_names (lambdaReturnType scan_fun)
        num_nes = length nes_inputs
        outer_inputs = nes_inputs ++ map (SOAC.transposeInput 0 1) arrs
        (acc_params, elem_params) =
          splitAt (length arrs) (lambdaParams scan_fun)
        outer_params =
          map removeParamOuterDim acc_params
            ++ map (setParamOuterDimTo w) elem_params
        outer_param_names = map paramName outer_params
        -- The inner map's lambda becomes the operator of the new inner scan.
        inner_nes = map Var $ take num_nes outer_param_names
        inner_arrs = drop num_nes outer_param_names

    inner_form <- scanSOAC [Futhark.Scan map_fun inner_nes]

    let inner_stm =
          Let (setPatternOuterDimTo w map_pat) (defAux ()) $
            Op $ Futhark.Screma w inner_arrs inner_form
        outer_body =
          mkBody (oneStm inner_stm) $ map Var $ patternNames map_pat
        outer_fun =
          Lambda outer_params outer_body $
            map (`setOuterSize` w) (lambdaReturnType scan_fun)
        -- NOTE(review): if the result of the new Screma has rank
        -- @2 + arrayRank t@, this permutation matches that rank only when
        -- @t@ is a scalar.  For @arrayRank t >= 1@ it has one entry too few.
        -- Confirm whether @[2 .. arrayRank t + 1]@ was intended.
        perm = case lambdaReturnType map_fun of
          [] -> []
          t : _ -> 1 : 0 : [2 .. arrayRank t]

    return
      ( SOAC.Screma map_w (ScremaForm [] [] outer_fun) outer_inputs,
        ots SOAC.|> SOAC.Rearrange map_cs perm
      )
iswim _ _ _ =
  fail "ISWIM does not apply."

-- | Replace a parameter's type with its row type, i.e. drop the outermost
-- dimension.
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim param = param {paramDec = rowType (paramType param)}

-- | Set the outermost dimension of a parameter's type to the given size.
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo w param =
  param {paramDec = setOuterSize (paramType param) w}

-- | Set the outermost dimension of every type in a pattern to the given
-- size.
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo w pat = (`setOuterSize` w) <$> pat

-- Now for fiddling with transpositions...

-- | Tag each input with whether its array is one of the @interesting@
-- names, then let 'commonTransforms'' factor out the transforms shared by
-- the tagged inputs.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting = commonTransforms' . map tag
  where
    tag inp = (SOAC.inputArray inp `elem` interesting, inp)

-- | Repeatedly peel a transform, via 'inputToOutput', off every input
-- tagged 'True'.
--
-- If at least one input is tagged and every tagged input yields the same
-- transform, that transform is put in front of the result and the process
-- repeats on the peeled inputs.  Inputs tagged 'False' are passed along
-- untouched.  Otherwise the remaining inputs are returned with
-- 'SOAC.noTransforms'.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' tagged =
  case foldM step (Nothing, []) tagged of
    Just (Just common, peeled_rev) ->
      -- The fold accumulates in reverse; restore the original order.
      let (rest, final_inps) = commonTransforms' (reverse peeled_rev)
       in (common SOAC.<| rest, final_inps)
    _ -> (SOAC.noTransforms, map snd tagged)
  where
    step (seen, acc) (True, inp) = do
      (ot, inp') <- inputToOutput inp
      case seen of
        Just prev_ot | prev_ot /= ot -> Nothing
        _ -> Just (Just ot, (True, inp') : acc)
    step (seen, acc) entry = Just (seen, entry : acc)

-- | One more than the smaller of:
--
-- * the number of nesting levels in the map nest;
-- * the minimum rank among the outermost result types.
--
-- The outermost result types are those of the first nesting level, or of
-- the lambda when there are no levels; an empty list counts as rank 0.
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  1 + min resDims (length levels)
  where
    resTypes = case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    resDims
      | null resTypes = 0
      | otherwise = minimum $ map arrayRank resTypes

-- | Move a leading 'SOAC.Rearrange' output transform onto the inputs of
-- the SOAC, viewed as a map nest.
--
-- Fails if any of the following holds:
--
-- * the SOAC is not a map nest;
-- * the first output transform is not a rearrange;
-- * the rearrange reaches deeper than 'mapDepth' allows.
--
-- On success, returns the rebuilt SOAC and the remaining output
-- transforms.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- return $ SOAC.viewf ots
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot pull transpose"
  let -- Pad perm with the identity on the trailing dimensions it does not
      -- mention, so that it covers the full rank of the given input.
      fullPerm inp =
        let rank = SOAC.inputRank inp
         in take rank perm ++ [length perm .. rank - 1]
      withPerm inp = SOAC.addTransform (SOAC.Rearrange cs $ fullPerm inp) inp
      nest' =
        MapNest.setInputs
          (map withPerm $ MapNest.inputs nest)
          (rearrangeReturnTypes nest perm)
  soac' <- MapNest.toSOAC nest'
  return (soac', ots')

-- | Try to move a rearrangement found on one of the named inputs of a
-- map nest to the output side of the nest.
--
-- 'fixupInputs' picks the permutation and adds it to every input of the
-- nest.  The nest's return types are permuted with 'rearrangeReturnTypes',
-- and the inverse permutation is prepended to the output transforms
-- @ots@.  Fails if 'MapNest.fromSOAC' yields no map nest, if
-- 'fixupInputs' finds no usable permutation, or if the permutation
-- reaches deeper than the nest ('mapDepth').
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  maybeNest <- MapNest.fromSOAC soac
  nest <- liftMaybe maybeNest
  (perm, fixedInputs) <- liftMaybe $ fixupInputs inpIds (MapNest.inputs nest)
  -- The permutation must fit within the map nesting to be pushed out.
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot push transpose"
  let permutedNest = rearrangeReturnTypes nest perm
      undoPerm = SOAC.Rearrange mempty (rearrangeInverse perm)
  soac' <- MapNest.toSOAC (MapNest.setInputs fixedInputs permutedNest)
  return (soac', undoPerm SOAC.<| ots)

-- | Permute the return types at every nesting level of a map nest by
-- @perm@.  The width, body and inputs of the nest are kept unchanged.
-- (Actually also rearranges indices.)
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    -- The permutation may be deeper than the rank of a type, but it is
    -- required to be an identity permutation beyond that rank.  This
    -- invariant is supposed to be checked by whoever calls
    -- rearrangeReturnTypes.
    permuteType :: Type -> Type
    permuteType t = rearrangeType (take (arrayRank t) perm) t

    permutedTypes :: [Type]
    permutedTypes = map permuteType (MapNest.typeOf nest)

    -- The k-th nesting (counting from 1) returns the permuted types with
    -- k outer dimensions removed.
    perLevelTypes :: [[Type]]
    perLevelTypes = drop 1 (iterate (map rowType) permutedTypes)

    nestings' =
      [ nesting {MapNest.nestingReturnType = levelTypes}
        | (nesting, levelTypes) <- zip nestings perLevelTypes
      ]

-- | Choose a permutation and apply it to every input.
--
-- The permutation is taken from the first input whose array is named in
-- @inpIds@ and whose last transform (as seen by 'SOAC.viewl') is a
-- 'SOAC.Rearrange'.  Each input then gets that permutation, truncated to
-- the input's rank, added as a new 'SOAC.Rearrange' transform.
-- Returns 'Nothing' in two cases: no such permutation exists, or some
-- input's rank is below the permutation's 'rearrangeReach'.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps = do
  perm <- listToMaybe $ mapMaybe inputRearrange $ filter exposable inps
  let reach = rearrangeReach perm
  inps' <- traverse (fixupInput reach perm) inps
  pure (perm, inps')
  where
    exposable :: SOAC.Input -> Bool
    exposable inp = SOAC.inputArray inp `elem` inpIds

    inputRearrange :: SOAC.Input -> Maybe [Int]
    inputRearrange (SOAC.Input ts _ _) =
      case SOAC.viewl ts of
        _ SOAC.:> SOAC.Rearrange _ perm -> Just perm
        _ -> Nothing

    fixupInput :: Int -> [Int] -> SOAC.Input -> Maybe SOAC.Input
    fixupInput d perm inp
      | rank >= d =
          Just $ SOAC.addTransform (SOAC.Rearrange mempty (take rank perm)) inp
      | otherwise = Nothing
      where
        rank = SOAC.inputRank inp

-- | Pull a leading 'SOAC.Reshape' output transform of a map into the
-- map itself.
--
-- The map becomes a nest of maps, one level per new dimension of the
-- reshape.  Its inputs get a 'SOAC.ReshapeOuter' transform.  This only
-- applies in three conditions together: the SOAC is a map
-- ('Futhark.isMapSOAC'), the first output transform is a 'SOAC.Reshape',
-- and every return type of the map lambda is a primitive type.  On
-- success the reshape is dropped from the returned transforms.
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape (SOAC.Screma _ form inps) ots
  | Just maplam <- Futhark.isMapSOAC form,
    SOAC.Reshape cs shape SOAC.:< ots' <- SOAC.viewf ots,
    all primType (lambdaReturnType maplam) = do
    let dims = newDims shape
        revDims = reverse dims
        -- Width of the innermost map: the last new dimension, or 0 if the
        -- reshape has no dimensions at all.
        innerWidth = fromMaybe (intConst Int64 0) (listToMaybe revDims)
        reshapedInps = map (SOAC.addTransform (SOAC.ReshapeOuter cs shape)) inps
        reshapedTypes = map SOAC.inputType reshapedInps
        -- Each pair is the width of an enclosing map together with the
        -- dimensions nested inside it; the list runs from the innermost
        -- enclosing map outwards.
        levels = zip (drop 1 revDims) (drop 1 (reverse (drop 1 (tails dims))))

        wrapInMap ::
          ([SOAC.Input] -> SOAC) ->
          (SubExp, [SubExp]) ->
          TryFusion ([SOAC.Input] -> SOAC)
        wrapInMap mkInner (w, outershape) = do
          lamParams <- forM reshapedTypes $ \inpt ->
            newParam "pullReshape_param" $
              stripArray (length shape - length outershape) inpt
          innerBody <-
            runBodyBinder $
              eBody [SOAC.toExp $ mkInner $ map (SOAC.identInput . paramIdent) lamParams]
          let withOuterDims t = arrayOf t (Shape outershape) NoUniqueness
              wrapperLam =
                Lambda
                  { lambdaParams = lamParams,
                    lambdaReturnType = map withOuterDims (lambdaReturnType maplam),
                    lambdaBody = innerBody
                  }
          return $ SOAC.Screma w (Futhark.mapSOAC wrapperLam)

    mkSoac <- foldM wrapInMap (SOAC.Screma innerWidth (Futhark.mapSOAC maplam)) levels
    return (mkSoac reshapedInps, ots')
pullReshape _ _ = fail "Cannot pull reshape"

-- Tie it all together in exposeInputs (for making inputs to a
-- consumer available) and pullOutputTransforms (for moving
-- output-transforms of a producer to its inputs instead).

-- | Try to make the inputs named by @inpIds@ exposed in the kernel.
-- An input counts as exposed when it has no transforms or is not
-- named in @inpIds@.  Returns the adjusted kernel together with the
-- transforms that 'commonTransforms' factored out of the inputs.
--
-- Three strategies are tried in order:
--
-- 1. Push a rearrangement out of the inputs with 'pushRearrange'.
-- 2. Pull the output transforms into the SOAC with 'pullRearrange'.
--    This is accepted only if it leaves no output transforms behind.
-- 3. Use the kernel as it is.
exposeInputs ::
  [VName] ->
  FusedKer ->
  TryFusion (FusedKer, SOAC.ArrayTransforms)
exposeInputs inpIds ker =
  viaPush <|> viaPull <|> exposeFrom ker
  where
    ot :: SOAC.ArrayTransforms
    ot = outputTransform ker

    viaPush :: TryFusion (FusedKer, SOAC.ArrayTransforms)
    viaPush = do
      (soac', ot') <- pushRearrange inpIds (fsoac ker) ot
      exposeFrom (ker {fsoac = soac', outputTransform = ot'})

    viaPull :: TryFusion (FusedKer, SOAC.ArrayTransforms)
    viaPull = do
      (soac', ot') <- pullRearrange (fsoac ker) ot
      unless (SOAC.nullTransforms ot') $
        fail "pullRearrange was not enough"
      exposeFrom (ker {fsoac = soac', outputTransform = SOAC.noTransforms})

    exposeFrom :: FusedKer -> TryFusion (FusedKer, SOAC.ArrayTransforms)
    exposeFrom ker' =
      case commonTransforms inpIds (inputs ker') of
        (ot', inps')
          | all exposed inps' ->
              return (ker' {fsoac = SOAC.setInputs inps' (fsoac ker')}, ot')
        _ -> fail "Cannot expose"

    -- An input is exposed if it carries no transforms, or if it is not
    -- one of the inputs we were asked about.
    exposed :: SOAC.Input -> Bool
    exposed inp@(SOAC.Input ts _ _) =
      SOAC.nullTransforms ts || SOAC.inputArray inp `notElem` inpIds

-- | The output-transform pullers tried by 'pullOutputTransforms', in
-- the order they are attempted.
outputTransformPullers ::
  [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers =
  [ pullRearrange,
    pullReshape
  ]

-- | Try to move the output transforms @ots@ of a SOAC into the SOAC.
--
-- Each puller in 'outputTransformPullers' is tried in turn.  Suppose a
-- puller succeeds but leaves transforms behind.  Then pulling is retried
-- on its result, and if the retry fails the partially-pulled result is
-- returned.  The whole operation fails only when no puller succeeds.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms soac ots = go outputTransformPullers
  where
    go [] = fail "Cannot pull anything"
    go (puller : rest) = tryWith puller <|> go rest

    tryWith puller = do
      (soac', ots') <- puller soac ots
      if SOAC.nullTransforms ots'
        then return (soac', SOAC.noTransforms)
        else pullOutputTransforms soac' ots' <|> return (soac', ots')