{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | It is well known that fully parallel loops can always be
-- interchanged inwards with a sequential loop.  This module
-- implements that transformation.
--
-- This is also where we implement loop-switching (for branches),
-- which is semantically similar to interchange.
module Futhark.Pass.ExtractKernels.Interchange
       (
         SeqLoop (..)
       , interchangeLoops
       , Branch (..)
       , interchangeBranch
       ) where

import Control.Monad.RWS.Strict
import Data.Maybe
import qualified Data.Map as M
import Data.List (find)

import Futhark.Pass.ExtractKernels.Distribution
  (LoopNesting(..), KernelNest, kernelNestLoops)
import Futhark.Representation.SOACS
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Tools

-- | An encoding of a sequential do-loop with no existential context,
-- alongside its result pattern.
--
-- The fields are, in order:
--
-- * a permutation, applied (via 'rearrangeShape') to the value elements
--   of the enclosing map nesting's pattern when the loop is interchanged
--   (NOTE(review): presumably it relates the loop pattern's order to the
--   nesting's result order; confirm at the construction site);
-- * the pattern bound by the loop;
-- * the merge parameters paired with their initial values;
-- * the loop form;
-- * the loop body.
data SeqLoop = SeqLoop [Int] Pattern [(FParam, SubExp)] (LoopForm SOACS) Body

-- | Turn a 'SeqLoop' back into a @do@-loop statement binding the loop's
-- pattern.  The loop has no context merge parameters, and the
-- permutation is not needed here.
seqLoopStm :: SeqLoop -> Stm
seqLoopStm (SeqLoop _ pat merge form body) =
  Let pat (defAux ()) $ DoLoop [] merge form body

-- | Interchange a single map nesting inwards with a sequential loop.
--
-- Given the loop @SeqLoop perm loop_pat merge form body@ sitting inside
-- @MapNesting pat cs w params_and_arrs@, produce a new 'SeqLoop' that:
--
-- * has every merge parameter expanded to an array with outer size @w@
--   (keeping its uniqueness);
-- * initialises an expanded merge parameter from the map's array when
--   its original initial value was a map parameter, and from a
--   @replicate w@ of the initial value otherwise;
-- * has as body a single @map@ (a 'Screma' with 'mapSOAC') over the
--   used map arrays and the expanded merge parameters, running the
--   original loop body;
-- * binds the nesting's value pattern, reordered by @perm@.
--
-- The first argument returns the array a map parameter is drawn from,
-- or 'Nothing' if the name is not a map parameter.
interchangeLoop :: (MonadBinder m, LocalScope SOACS m) =>
                   (VName -> Maybe VName) -> SeqLoop -> LoopNesting
                -> m SeqLoop
interchangeLoop
  isMapParameter
  (SeqLoop perm loop_pat merge form body)
  (MapNesting pat cs w params_and_arrs) = do
    merge_expanded <-
      localScope (scopeOfLParams $ map fst params_and_arrs) $
      mapM expand merge

    let loop_pat_expanded =
          Pattern [] $ map expandPatElem $ patternElements loop_pat
        -- The original merge parameters become (non-unique) parameters
        -- of the inner map lambda.
        new_params = [ Param pname $ fromDecl ptype
                     | (Param pname ptype, _) <- merge ]
        new_arrs = map (paramName . fst) merge_expanded
        rettype = map rowType $ patternTypes loop_pat_expanded

        -- Map parameters (and their arrays) that the loop body no longer
        -- uses are dropped.  This happens e.g. when a parameter was only
        -- used as the initial value of a merge parameter.
        --
        -- NOTE(review): an earlier comment said that arrays bound outside
        -- the loop and consumed by the map must be copied, but no copy
        -- has ever been made here; check whether consumption of such
        -- arrays needs handling.
        (params', arrs') = unzip $ mapMaybe keepIfUsed params_and_arrs

    body' <- mkDummyStms (params' <> new_params) body

    let lam = Lambda (params' <> new_params) body' rettype
        map_bnd = Let loop_pat_expanded (StmAux cs ()) $
                  Op $ Screma w (mapSOAC lam) $ arrs' <> new_arrs
        res = map Var $ patternNames loop_pat_expanded
        pat' = Pattern [] $ rearrangeShape perm $ patternValueElements pat

    return $
      SeqLoop perm pat' merge_expanded form $
      mkBody (oneStm map_bnd) res

  where free_in_body = freeIn body

        -- Keep a map parameter only if the loop body refers to it.
        keepIfUsed (param, arr)
          | paramName param `nameIn` free_in_body = Just (param, arr)
          | otherwise = Nothing

        -- A merge parameter initialised from a map parameter can start
        -- from the corresponding array; anything else is replicated.
        expandedInit _ (Var v)
          | Just arr <- isMapParameter v =
              return $ Var arr
        expandedInit param_name se =
          letSubExp (param_name <> "_expanded_init") $
            BasicOp $ Replicate (Shape [w]) se

        -- Turn a merge parameter of type @t@ into a fresh one of type
        -- @[w]t@ with the same uniqueness, together with its new
        -- initial value.
        expand (merge_param, merge_init) = do
          expanded_param <-
            newParam (param_name <> "_expanded") $
            arrayOf (paramDeclType merge_param) (Shape [w]) $
            uniqueness $ declTypeOf merge_param
          expanded_init <- expandedInit param_name merge_init
          return (expanded_param, expanded_init)
            where param_name = baseString $ paramName merge_param

        expandPatElem (PatElem name t) =
          PatElem name $ arrayOfRow t w

        -- The kernel extractor cannot handle identity mappings, so
        -- insert dummy statements for body results that are just a
        -- lambda parameter.
        mkDummyStms params (Body () stms res) = do
          (res', extra_stms) <- unzip <$> mapM dummyStm res
          return $ Body () (stms <> mconcat extra_stms) res'
          where dummyStm (Var v)
                  | Just p <- find ((== v) . paramName) params = do
                      dummy <- newVName (baseString v ++ "_dummy")
                      return (Var dummy,
                              oneStm $
                                Let (Pattern [] [PatElem dummy $ paramType p])
                                    (defAux ()) $
                                    BasicOp $ SubExp $ Var $ paramName p)
                dummyStm se = return (se, mempty)

-- | Given a (parallel) map nesting and an inner sequential loop, move
-- the maps inside the sequential loop.  The result is several
-- statements: any bound while interchanging (e.g. replicated initial
-- values), followed by the loop, whose body contains @map@ SOACs
-- (expressed as 'Screma' with 'mapSOAC').
interchangeLoops :: (MonadFreshNames m, HasScope SOACS m) =>
                    KernelNest -> SeqLoop
                 -> m (Stms SOACS)
interchangeLoops nest loop = do
  -- Interchange with the innermost nesting first, working outwards.
  (loop', bnds) <-
    runBinder $ foldM (interchangeLoop isMapParameter) loop $
    reverse $ kernelNestLoops nest
  return $ bnds <> oneStm (seqLoopStm loop')
  where
    -- The array that a map parameter (at any nesting level) is drawn
    -- from, if the name is a map parameter at all.
    isMapParameter v =
      fmap snd $ find ((== v) . paramName . fst) $
      concatMap loopNestingParamsAndArrs $ kernelNestLoops nest

-- | An encoding of an @if@ statement alongside its result pattern, in
-- the same style as 'SeqLoop'.
--
-- The fields are, in order: a permutation (applied via 'rearrangeShape'
-- to the enclosing map nesting's pattern during interchange, like the
-- one in 'SeqLoop'); the pattern bound by the branch; the condition;
-- the true branch; the false branch; and the branch's return annotation.
data Branch = Branch [Int] Pattern SubExp Body Body (IfAttr (BranchType SOACS))

-- | Turn a 'Branch' back into an @if@ statement binding the branch's
-- pattern.  The permutation is not needed here.
branchStm :: Branch -> Stm
branchStm (Branch _ pat cond tbranch fbranch ret) =
  Let pat (defAux ()) $ If cond tbranch fbranch ret

-- | Interchange a single map nesting inwards with a branch.
--
-- Each arm of the resulting branch is a single @map@ (a 'Screma' with
-- 'mapSOAC') over the nesting's arrays.  Its lambda runs the original
-- arm, and the whole arm is renamed afterwards.  The new branch returns
-- arrays with outer size @w@, binds the nesting's value pattern
-- reordered by @perm@, and carries the identity permutation.
interchangeBranch1 :: (MonadBinder m, LocalScope SOACS m) =>
                      Branch -> LoopNesting -> m Branch
interchangeBranch1
  (Branch perm branch_pat cond tbranch fbranch (IfAttr ret if_sort))
  (MapNesting pat cs w params_and_arrs) = do
    let ret' = map (`arrayOfRow` Free w) ret
        pat' = Pattern [] $ rearrangeShape perm $ patternValueElements pat

        (params, arrs) = unzip params_and_arrs
        -- The lambda returns the arm's results in the order of
        -- branch_pat, so the nesting's row types must be permuted into
        -- that order, exactly as the elements of pat' are.
        lam_ret = rearrangeShape perm $ map rowType $ patternTypes pat

        branch_pat' =
          Pattern [] $ map (fmap (`arrayOfRow` w)) $ patternElements branch_pat

        mkBranch branch = (renameBody =<<) $ do
          let bound_in_branch = scopeOf (bodyStms branch)
          -- The kernel extractor does not like identity mappings, so
          -- every result that is not bound inside the arm itself gets
          -- a dummy binding.
          branch' <-
            runBodyBinder $
            resultBody <$>
            (mapM (dummyBindIfNotIn bound_in_branch) =<< bodyBind branch)
          let lam = Lambda params branch' lam_ret
              res = map Var $ patternNames branch_pat'
              map_bnd = Let branch_pat' (StmAux cs ()) $
                        Op $ Screma w (mapSOAC lam) arrs
          return $ mkBody (oneStm map_bnd) res

    tbranch' <- mkBranch tbranch
    fbranch' <- mkBranch fbranch
    return $ Branch [0 .. patternSize pat - 1] pat' cond tbranch' fbranch' $
      IfAttr ret' if_sort
  where -- Bind a subexpression to a fresh name and return that name.
        dummyBind se = do
          dummy <- newVName "dummy"
          letBindNames_ [dummy] (BasicOp $ SubExp se)
          return $ Var dummy

        -- Leave names bound inside the arm alone; rebind everything else.
        dummyBindIfNotIn bound_in_branch se
          | Var v <- se,
            v `M.member` bound_in_branch = return se
          | otherwise = dummyBind se

-- | Given a (parallel) map nesting and an inner branch, move the maps
-- inside both arms of the branch.  The result is any statements bound
-- while interchanging, followed by the new @if@ statement.
interchangeBranch :: (MonadFreshNames m, HasScope SOACS m) =>
                     KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch nest branch = do
  -- Interchange with the innermost nesting first, working outwards.
  (branch', bnds) <-
    runBinder $ foldM interchangeBranch1 branch $ reverse $ kernelNestLoops nest
  return $ bnds <> oneStm (branchStm branch')