{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Interchanging scans and reductions with inner maps.
module Futhark.Pass.ExtractKernels.ISRWIM
       ( iswim
       , irwim
       , rwimPossible
       )
       where

import Control.Arrow (first)
import Control.Monad.State

import Futhark.MonadFreshNames
import Futhark.IR.SOACS
import Futhark.Tools

-- | Interchange Scan With Inner Map.  Tries to turn a @scan(map)@ into a
-- @map(scan)@.
--
-- This only succeeds when 'rwimPossible' accepts the scan operator, i.e. the
-- operator's body is a single @map@ over exactly its own parameters whose
-- results are returned verbatim.  The generated action
--
--   1. rearranges every scanned array with permutation @[1,0,2..]@,
--   2. binds an outer @map@ over @map_w@ whose body performs a scan over @w@
--      using the inner map's function as the operator, and
--   3. rearranges each result back into the corresponding name of @res_pat@.
--
-- Returns 'Nothing' when the interchange is not applicable.
iswim :: (MonadBinder m, Lore m ~ SOACS)
      => Pattern           -- ^ Pattern bound by the original scan.
      -> SubExp            -- ^ Width of the original scan.
      -> Lambda            -- ^ The scan operator.
      -> [(SubExp, VName)] -- ^ Neutral elements paired with scanned arrays.
      -> Maybe (m ())
iswim res_pat w scan_fun scan_input
  | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun = Just $ do
      let (accs, arrs) = unzip scan_input
      arrs' <- transposedArrays arrs
      -- Bind each neutral element to a name so it can be a map input.
      accs' <- mapM (letExp "acc" . BasicOp . SubExp) accs

      let map_arrs' = accs' ++ arrs'
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          -- Accumulator parameters lose their outer dimension; element
          -- parameters get their outer dimension set to the scan width.
          map_params = map removeParamOuterDim scan_acc_params ++
                       map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (setOuterDimTo w) $ lambdaReturnType scan_fun

          -- The inner map's function becomes the new scan operator.
          scan_params = lambdaParams map_fun
          scan_body = lambdaBody map_fun
          scan_rettype = lambdaReturnType map_fun
          scan_fun' = Lambda scan_params scan_body scan_rettype
          -- Pair the accumulator parameters (as neutral elements) with the
          -- element parameters (as the arrays to scan).
          scan_input' = map (first Var) $
                        uncurry zip $ splitAt (length arrs') $ map paramName map_params
          (nes', scan_arrs) = unzip scan_input'

      scan_soac <- scanSOAC [Scan scan_fun' nes']
      let map_body =
            mkBody (oneStm $ Let (setPatternOuterDimTo w map_pat) (defAux ()) $
                      Op $ Screma w scan_soac scan_arrs) $
              map Var $ patternNames map_pat
          map_fun' = Lambda map_params map_body map_rettype

      -- The outer map produces the results in rearranged form; bind them
      -- to fresh "_transposed" names first.
      res_pat' <- basicPattern [] <$>
                  mapM (newIdent' (<> "_transposed") . transposeIdentType)
                       (patternValueIdents res_pat)

      addStm $ Let res_pat' (StmAux map_cs mempty ()) $
        Op $ Screma map_w (mapSOAC map_fun') map_arrs'

      -- Rearrange each result back into the name expected by res_pat.
      forM_ (zip (patternValueIdents res_pat) (patternValueIdents res_pat')) $
        \(to, from) -> do
          let perm = [1, 0] ++ [2 .. arrayRank (identType from) - 1]
          addStm $ Let (basicPattern [] [to]) (defAux ()) $
            BasicOp $ Rearrange perm $ identName from
  | otherwise = Nothing

-- | Interchange Reduce With Inner Map.  Tries to turn a @reduce(map)@ into a
-- @map(reduce)@.
--
-- This only succeeds when 'rwimPossible' accepts the reduction operator.
-- The reduced arrays are rearranged with permutation @[1,0,2..]@, and an
-- outer @map@ over @map_w@ is bound to @res_pat@.  Its body reduces over @w@
-- using the inner map's function as the operator.  If that inner operator
-- is itself accepted by 'rwimPossible', the interchange is applied
-- recursively to build the body.
--
-- Returns 'Nothing' when the interchange is not applicable.
irwim :: (MonadBinder m, Lore m ~ SOACS)
      => Pattern           -- ^ Pattern bound by the original reduction.
      -> SubExp            -- ^ Width of the original reduction.
      -> Commutativity     -- ^ Commutativity of the reduction operator.
      -> Lambda            -- ^ The reduction operator.
      -> [(SubExp, VName)] -- ^ Neutral elements paired with reduced arrays.
      -> Maybe (m ())
irwim res_pat w comm red_fun red_input
  | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible red_fun = Just $ do
      let (accs, arrs) = unzip red_input
      arrs' <- transposedArrays arrs
      -- FIXME?  Can we reasonably assume that the accumulator is a
      -- replicate?  We also assume that it is non-empty.
      --
      -- The inner neutral element is taken to be element 0 of each outer
      -- neutral element.
      let indexAcc (Var v) = do
            v_t <- lookupType v
            letSubExp "acc" $ BasicOp $ Index v $
              fullSlice v_t [DimFix $ intConst Int32 0]
          indexAcc Constant{} =
            error "irwim: array accumulator is a constant."
      accs' <- mapM indexAcc accs

      let (_red_acc_params, red_elem_params) =
            splitAt (length arrs) $ lambdaParams red_fun
          -- Each outer-map iteration produces one row of every result of
          -- the original operator.
          map_rettype = map rowType $ lambdaReturnType red_fun
          map_params = map (setParamOuterDimTo w) red_elem_params

          -- The inner map's function becomes the new reduction operator.
          red_params = lambdaParams map_fun
          red_body = lambdaBody map_fun
          red_rettype = lambdaReturnType map_fun
          red_fun' = Lambda red_params red_body red_rettype
          red_input' = zip accs' $ map paramName map_params
          red_pat = stripPatternOuterDim map_pat
          map_res = map Var $ patternNames map_pat

      map_body <-
        case irwim red_pat w comm red_fun' red_input' of
          Nothing -> do
            -- Base case: emit the inner reduction directly.
            reduce_soac <- reduceSOAC [Reduce comm red_fun' $ map fst red_input']
            return $ mkBody (oneStm $ Let red_pat (defAux ()) $
                               Op $ Screma w reduce_soac $ map snd red_input')
                            map_res
          Just interchange -> localScope (scopeOfLParams map_params) $ do
            -- Recursive case: the inner operator is itself interchangeable,
            -- so its generated statements become the outer map's body.
            map_body_bnds <- collectStms_ interchange
            return $ mkBody map_body_bnds map_res

      let map_fun' = Lambda map_params map_body map_rettype

      addStm $ Let res_pat (StmAux map_cs mempty ()) $
        Op $ Screma map_w (mapSOAC map_fun') arrs'
  | otherwise = Nothing

-- | Does this reduce operator contain an inner map, and if so, what
-- does that map look like?
--
-- The operator qualifies when its body holds exactly one statement which is
-- a 'Screma' recognised by 'isMapSOAC'.  That map must take the operator's
-- own parameters, in order, as its input arrays, and the names it binds
-- must be returned verbatim as the body result.  On success this yields
-- the map's pattern, the statement's certificates, the map width and the
-- map's lambda.
rwimPossible :: Lambda
             -> Maybe (Pattern, Certificates, SubExp, Lambda)
rwimPossible fun =
  case stmsToList stms of
    [stm]
      | Op (Screma map_w form map_arrs) <- stmExp stm,
        Just map_fun <- isMapSOAC form,
        map paramName (lambdaParams fun) == map_arrs,
        map Var (patternNames (stmPattern stm)) == res ->
          Just (stmPattern stm, stmCerts stm, map_w, map_fun)
    _ -> Nothing
  where Body _ stms res = lambdaBody fun

-- | Bind, for every input array, a 'Rearrange' of it with permutation
-- @[1,0,2..rank-1]@ (the two outermost dimensions swapped).  Each new name
-- is derived from the base name of the array it came from.  The returned
-- names are in the same order as the inputs.
transposedArrays :: MonadBinder m => [VName] -> m [VName]
transposedArrays arrs =
  mapM
    (\arr -> do
        rank <- arrayRank <$> lookupType arr
        let perm = 1 : 0 : [2 .. rank - 1]
        letExp (baseString arr) $ BasicOp $ Rearrange perm arr)
    arrs

-- | Give a lambda parameter the row type of its current type, i.e. drop the
-- outermost dimension.  The parameter name is kept.
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim param = param { paramDec = row_t }
  where row_t = rowType (paramType param)

-- | Replace the outermost dimension of a lambda parameter's type with @w@,
-- keeping the parameter name.
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo w param = param { paramDec = resized }
  where resized = setOuterDimTo w (paramType param)

-- | Replace the outermost dimension of an identifier's type with @w@,
-- keeping the identifier's name.
setIdentOuterDimTo :: SubExp -> Ident -> Ident
setIdentOuterDimTo w ident = ident { identType = resized }
  where resized = setOuterDimTo w (identType ident)

-- | Rebuild an array type with its outermost dimension replaced by @w@: the
-- row type is kept and re-wrapped in an outer dimension of size @w@.
setOuterDimTo :: SubExp -> Type -> Type
setOuterDimTo w t = rowType t `arrayOfRow` w

-- | Set the outermost dimension of every value identifier in a pattern to
-- @w@.  Only the value identifiers of @pat@ are used; 'basicPattern' is
-- given an empty first (context) list.
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo w pat =
  basicPattern [] [ setIdentOuterDimTo w ident | ident <- patternValueIdents pat ]

-- | Replace an identifier's type with the result of 'transposeType' on it,
-- keeping the name.  In this module it types names that hold the
-- rearranged (outer-two-dimensions-swapped) results of 'iswim'.
transposeIdentType :: Ident -> Ident
transposeIdentType ident = ident { identType = flipped }
  where flipped = transposeType (identType ident)

-- | Give an identifier the row type of its current type (drop the outermost
-- dimension), keeping its name.
stripIdentOuterDim :: Ident -> Ident
stripIdentOuterDim ident = ident { identType = row_t }
  where row_t = rowType (identType ident)

-- | Drop the outermost dimension of every value identifier in a pattern.
-- Only the value identifiers of @pat@ are used; 'basicPattern' is given an
-- empty first (context) list.
stripPatternOuterDim :: Pattern -> Pattern
stripPatternOuterDim pat =
  basicPattern [] [ stripIdentOuterDim ident | ident <- patternValueIdents pat ]