{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
  ( lowerUpdateKernels
  , lowerUpdate
  , LowerUpdate
  , DesiredUpdate (..)
  ) where

import Control.Monad
import Control.Monad.Writer
import Data.List (find, unzip4)
import Data.Maybe (mapMaybe)
import Data.Either
import qualified Data.Map as M

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop.Aliases
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.Construct
import Futhark.Optimise.InPlaceLowering.SubstituteIndices

-- | A description of an in-place update that we would like to lower
-- into the statement producing the value being written.
data DesiredUpdate dec =
  DesiredUpdate { updateName :: VName -- ^ Name of result.
                , updateType :: dec -- ^ Type of result.
                , updateCertificates :: Certificates
                  -- ^ Certificates of the update; these are attached
                  -- to the statement the update is lowered into.
                , updateSource :: VName
                  -- ^ The array being updated.
                , updateIndices :: Slice SubExp
                  -- ^ Where in the source array the value is written.
                , updateValue :: VName
                  -- ^ Name of the value being written.
                }
  deriving (Show)

-- | Mapping over a 'DesiredUpdate' transforms only the 'updateType'
-- field; every other field is left as it is.
instance Functor DesiredUpdate where
  f `fmap` u = u { updateType = f $ updateType u }

-- | Does the given update write the value with the given name?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name = (name==) . updateValue

-- | The shape of a function that tries to lower a list of desired
-- updates into a statement.  It is given the scope, the statement,
-- and the updates, and produces 'Nothing' when the lowering is not
-- possible, or otherwise an action yielding the replacement
-- statements.
type LowerUpdate lore m =
  Scope (Aliases lore) ->
  Stm (Aliases lore) ->
  [DesiredUpdate (LetDec (Aliases lore))] ->
  Maybe (m [Stm (Aliases lore)])

-- | Try to lower the desired updates into the given statement.  Two
-- kinds of statement are handled:
--
--   * A 'DoLoop', which is delegated to 'lowerUpdateIntoLoop'.  On
--     success the loop is rebuilt from the new pattern, merge
--     parameters and body, keeps the certificates of the original
--     statement, and is surrounded by the statements that must come
--     before and after it.
--
--   * A plain variable copy whose pattern binds exactly the source
--     of a single desired update.  This becomes one 'Update' of the
--     copied variable, carrying the certificates of both the
--     original statement and the update.
--
-- Any other combination yields 'Nothing'.
lowerUpdate :: (MonadFreshNames m, Bindable lore,
                LetDec lore ~ Type, CanBeAliased (Op lore)) => LowerUpdate lore m
lowerUpdate scope (Let pat aux (DoLoop ctx val form body)) updates = do
  -- Running in 'Maybe': if the loop cannot absorb the updates, the
  -- whole lowering fails here.
  canDo <- lowerUpdateIntoLoop scope updates pat ctx val form body
  Just $ do
    (prebnds, postbnds, ctxpat, valpat, ctx', val', body') <- canDo
    return $
      prebnds ++ [certify (stmAuxCerts aux) $
                  mkLet ctxpat valpat $ DoLoop ctx' val' form body'] ++ postbnds
lowerUpdate _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src is val]
  | patternNames pat == [src] =
    -- Pad the update indices to a full slice of the result type.
    let is' = fullSlice (typeOf bindee_dec) is
    in Just $
       return [certify (stmAuxCerts aux <> cs) $
               mkLet [] [Ident bindee_nm $ typeOf bindee_dec] $
               BasicOp $ Update v is' $ Var val]
lowerUpdate _ _ _ =
  Nothing

-- | Like 'lowerUpdate', but additionally able to lower updates into
-- a 'SegMap'.  The 'SegMap' case is attempted only when every
-- update writes a value that is bound by the pattern of the
-- statement; if that case is selected but
-- 'lowerUpdatesIntoSegMap' fails, the result is 'Nothing'.  All
-- other statements are passed on to 'lowerUpdate'.
lowerUpdateKernels :: MonadFreshNames m => LowerUpdate Kernels m
lowerUpdateKernels scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody)))) updates
  | all ((`elem` patternNames pat) . updateValue) updates = do
      mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
      Just $ do
        (pat', kbody', poststms) <- mk
        -- The rewritten kernel now performs the updates itself, so
        -- it must carry their certificates too.
        let cs = stmAuxCerts aux <> foldMap updateCertificates updates
        return $
          certify cs (Let pat' aux $ Op $ SegOp $ SegMap lvl space ts kbody') :
          stmsToList poststms
lowerUpdateKernels scope stm updates = lowerUpdate scope stm updates

-- | Try to lower the desired updates into the body of a 'SegMap'
-- with the given pattern, space and kernel body.  Pattern elements
-- and kernel results are processed pairwise.  A result whose
-- pattern element is the value of some update must be a 'Returns';
-- it is turned into a 'WriteReturns' into the source array of the
-- update, and the pattern element is replaced by the result of the
-- update.  If any such result is not a 'Returns', the whole
-- lowering fails with 'Nothing'.
--
-- On success, the action produces the new pattern, the new kernel
-- body (with the index computations appended to its statements),
-- and statements to insert after the kernel, which rebind each
-- replaced name by indexing the updated array.
lowerUpdatesIntoSegMap :: MonadFreshNames m =>
                          Scope (Aliases Kernels)
                       -> Pattern (Aliases Kernels)
                       -> [DesiredUpdate (LetDec (Aliases Kernels))]
                       -> SegSpace
                       -> KernelBody (Aliases Kernels)
                       -> Maybe (m (Pattern (Aliases Kernels),
                                    KernelBody (Aliases Kernels),
                                    Stms (Aliases Kernels)))
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  -- The updates are all-or-nothing.  Being more liberal would require
  -- changes to the in-place-lowering pass itself.
  mk <- zipWithM onRet (patternElements pat) (kernelBodyResult kbody)
  return $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    return (Pattern [] pes,
            kbody { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms
                  , kernelBodyResult = krets
                  },
            mconcat poststms)

  where (gtids, _dims) = unzip $ unSegSpace kspace

        -- Handle a single pattern element and its kernel result.
        onRet (PatElem v v_dec) ret
          | Just (DesiredUpdate bindee_nm bindee_dec _cs src slice _val) <-
              find ((==v) . updateValue) updates = do

              -- Only plain 'Returns' results can be rewritten; the
              -- failing pattern match yields 'Nothing' in 'Maybe'.
              Returns _ se <- Just ret

              Just $ do
                let pexp = primExpFromSubExp int32
                -- Compute the indices each thread writes to, by
                -- applying 'fixSlice' to the update slice and the
                -- thread indices of the space.
                (slice', bodystms) <- flip runBinderT scope $
                  traverse (toSubExp "index") $
                  fixSlice (map (fmap pexp) slice) $
                  map (pexp . Var) gtids

                let res_dims = arrayDims $ snd bindee_dec
                    ret' = WriteReturns res_dims src [(map DimFix slice', se)]

                return (PatElem bindee_nm bindee_dec,
                        bodystms,
                        ret',
                        -- Rebind the original name by reading the
                        -- slice back from the updated array.
                        oneStm $ mkLet [] [Ident v $ typeOf v_dec] $
                        BasicOp $ Index bindee_nm slice)

        -- Results not involved in any update are left untouched.
        onRet pe ret =
          Just $ return (pe, mempty, ret, mempty)

lowerUpdateIntoLoop :: (Bindable lore, BinderOps lore,
                        Aliased lore, LetDec lore ~ (als, Type),
                        MonadFreshNames m) =>
                       Scope lore
                    -> [DesiredUpdate (LetDec lore)]
                    -> Pattern lore
                    -> [(FParam lore, SubExp)]
                    -> [(FParam lore, SubExp)]
                    -> LoopForm lore
                    -> Body lore
                    -> Maybe (m ([Stm lore],
                                 [Stm lore],
                                 [Ident],
                                 [Ident],
                                 [(FParam lore, SubExp)],
                                 [(FParam lore, SubExp)],
                                 Body lore))
lowerUpdateIntoLoop :: Scope lore
-> [DesiredUpdate (LetDec lore)]
-> Pattern lore
-> [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> Body lore
-> Maybe
     (m ([Stm lore], [Stm lore], [Ident], [Ident],
         [(FParam lore, SubExp)], [(FParam lore, SubExp)], Body lore))
lowerUpdateIntoLoop Scope lore
scope [DesiredUpdate (LetDec lore)]
updates Pattern lore
pat [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form Body lore
body = do
  -- Algorithm:
  --
  --   0) Map each result of the loop body to a corresponding in-place
  --      update, if one exists.
  --
  --   1) Create new merge variables corresponding to the arrays being
  --      updated; extend the pattern and the @res@ list with these,
  --      and remove the parts of the result list that have a
  --      corresponding in-place update.
  --
  --      (The creation of the new merge variable identifiers is
  --      actually done at the same time as step (0)).
  --
  --   2) Create in-place updates at the end of the loop body.
  --
  --   3) Create index expressions that read back the values written
  --      in (2).  If the merge parameter corresponding to this value
  --      is unique, also @copy@ this value.
  --
  --   4) Update the result of the loop body to properly pass the new
  --      arrays and indexed elements to the next iteration of the
  --      loop.
  --
  -- We also check that the merge parameters we work with have
  -- loop-invariant shapes.

  -- Safety condition (8).
  [((Param DeclType, SubExp), Names)]
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Param DeclType, SubExp)]
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val ([Names] -> [((Param DeclType, SubExp), Names)])
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. (a -> b) -> a -> b
$ Body lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases Body lore
body) ((((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ())
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall a b. (a -> b) -> a -> b
$ \((Param DeclType
p, SubExp
_), Names
als) ->
    Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p VName -> Names -> Bool
`nameIn` Names
als

  m [LoopResultSummary (als, Type)]
mk_in_place_map <- [DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
forall (m :: * -> *) als.
MonadFreshNames m =>
[DesiredUpdate (als, Type)]
-> Names
-> [(SubExp, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop [DesiredUpdate (als, Type)]
[DesiredUpdate (LetDec lore)]
updates Names
usedInBody [(SubExp, Ident)]
resmap [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
val

  m ([Stm lore], [Stm lore], [Ident], [Ident],
   [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
     (m ([Stm lore], [Stm lore], [Ident], [Ident],
         [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a. a -> Maybe a
Just (m ([Stm lore], [Stm lore], [Ident], [Ident],
    [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
 -> Maybe
      (m ([Stm lore], [Stm lore], [Ident], [Ident],
          [(Param DeclType, SubExp)], [(Param DeclType, SubExp)],
          Body lore)))
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
      [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> Maybe
     (m ([Stm lore], [Stm lore], [Ident], [Ident],
         [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore))
forall a b. (a -> b) -> a -> b
$ do
    [LoopResultSummary (als, Type)]
in_place_map <- m [LoopResultSummary (als, Type)]
mk_in_place_map
    ([(Param DeclType, SubExp)]
val',[Stm lore]
prebnds,[Stm lore]
postbnds) <- [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
forall (m :: * -> *) lore als.
(MonadFreshNames m, Bindable lore) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
mkMerges [LoopResultSummary (als, Type)]
in_place_map
    let ([Ident]
ctxpat,[Ident]
valpat) = [LoopResultSummary (als, Type)] -> ([Ident], [Ident])
mkResAndPat [LoopResultSummary (als, Type)]
in_place_map
        idxsubsts :: IndexSubstitutions (als, Type)
idxsubsts = [LoopResultSummary (als, Type)] -> IndexSubstitutions (als, Type)
forall dec. [LoopResultSummary dec] -> IndexSubstitutions dec
indexSubstitutions [LoopResultSummary (als, Type)]
in_place_map
    (IndexSubstitutions (als, Type)
idxsubsts', Stms lore
newbnds) <- IndexSubstitutions (als, Type)
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall (m :: * -> *) lore dec.
(MonadFreshNames m, BinderOps lore, Bindable lore, Aliased lore,
 LetDec lore ~ dec) =>
IndexSubstitutions dec
-> Stms lore -> m (IndexSubstitutions dec, Stms lore)
substituteIndices IndexSubstitutions (als, Type)
idxsubsts (Stms lore -> m (IndexSubstitutions (als, Type), Stms lore))
-> Stms lore -> m (IndexSubstitutions (als, Type), Stms lore)
forall a b. (a -> b) -> a -> b
$ Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
body
    ([SubExp]
body_res, Stms lore
res_bnds) <- [LoopResultSummary (LetDec lore)]
-> IndexSubstitutions (LetDec lore) -> m ([SubExp], Stms lore)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[LoopResultSummary (LetDec lore)]
-> IndexSubstitutions (LetDec lore) -> m ([SubExp], Stms lore)
manipulateResult [LoopResultSummary (als, Type)]
[LoopResultSummary (LetDec lore)]
in_place_map IndexSubstitutions (als, Type)
IndexSubstitutions (LetDec lore)
idxsubsts'
    let body' :: Body lore
body' = Stms lore -> [SubExp] -> Body lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (Stms lore
newbndsStms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<>Stms lore
res_bnds) [SubExp]
body_res
    ([Stm lore], [Stm lore], [Ident], [Ident],
 [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
-> m ([Stm lore], [Stm lore], [Ident], [Ident],
      [(Param DeclType, SubExp)], [(Param DeclType, SubExp)], Body lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm lore]
prebnds, [Stm lore]
postbnds, [Ident]
ctxpat, [Ident]
valpat, [(Param DeclType, SubExp)]
[(FParam lore, SubExp)]
ctx, [(Param DeclType, SubExp)]
val', Body lore
body')
  -- 'usedInBody': every name free in the loop body or the loop form,
  -- each expanded by 'expandAliases'.
  where usedInBody =
          foldMap expandAliases $ namesToList $ freeIn body <> freeIn form

        -- A name that the scope records as let-bound is expanded to
        -- itself plus the aliases stored in its decoration.  Any other
        -- name (including one absent from the scope) stands only for
        -- itself.
        expandAliases v =
          case M.lookup v scope of
            Just (LetName dec) -> oneName v <> aliasesOf dec
            _ -> oneName v

        -- Pairs each result of the loop body with the corresponding
        -- value identifier of the pattern (positionally, via 'zip').
        resmap = zip (bodyResult body) $ patternValueIdents pat

        -- Run 'mkMerge' over every summary and assemble the outcome.
        --
        -- The returned merge parameters list the ones 'mkMerge' tagged
        -- 'Left' first, followed by the ones tagged 'Right'.  The two
        -- statement lists are the first and second components of the
        -- writer output accumulated by 'mkMerge'.
        mkMerges :: (MonadFreshNames m, Bindable lore) =>
                    [LoopResultSummary (als, Type)]
                 -> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
        mkMerges summaries = do
          (merges, (prebnds, postbnds)) <-
            runWriterT $ mapM mkMerge summaries
          let (origmerge, extramerge) = partitionEithers merges
          return (origmerge ++ extramerge, prebnds, postbnds)

        -- Produce the merge parameter for a single loop result.
        --
        -- Without a related update the original merge parameter is
        -- returned unchanged, tagged 'Left'.
        --
        -- With a related update, a fresh name ("modified_source") is
        -- created and two statements are written:
        --
        --   * first writer component: binds the fresh name to an
        --     'Update' of the update's source array, writing the
        --     sub-expression paired with the original merge parameter
        --     at the update's indices;
        --
        --   * second writer component: binds the update's value name
        --     to an 'Index' of the update's result array at the same
        --     indices.
        --
        -- The result, tagged 'Right', is a new 'Unique' parameter named
        -- after the lowered array and initialised to the fresh name.
        mkMerge summary =
          case relatedUpdate summary of
            Nothing ->
              return $ Left $ mergeParam summary
            Just (update, mergename, mergedec) -> do
              source <- newVName "modified_source"
              let source_t = snd $ updateType update
                  -- Both statements address the same slice of the array.
                  slice = fullSlice source_t $ updateIndices update
                  elmident =
                    Ident (updateValue update) $
                    source_t `setArrayDims` sliceDims (updateIndices update)
                  mergeinit = snd $ mergeParam summary
              tell ( [ mkLet [] [Ident source source_t] $ BasicOp $
                         Update (updateSource update) slice mergeinit
                     ]
                   , [ mkLet [] [elmident] $ BasicOp $
                         Index (updateName update) slice
                     ]
                   )
              return $ Right ( Param mergename (toDecl (typeOf mergedec) Unique)
                             , Var source
                             )

        -- Compute the identifiers of the new loop pattern.
        --
        -- The first component is the context identifiers of the
        -- original pattern, unchanged.  The second component lists the
        -- identifiers 'mkResAndPat'' tagged 'Left' first, followed by
        -- the ones tagged 'Right'.
        mkResAndPat summaries =
          (patternContextIdents pat, origpat ++ extrapat)
          where
            (origpat, extrapat) =
              partitionEithers $ map mkResAndPat' summaries

        -- Choose the pattern identifier for a single loop result.
        --
        -- With a related update, the identifier is the update's result
        -- name with the type component of the update's decoration,
        -- tagged 'Right'.  Otherwise it is the identifier the result
        -- was originally bound to, tagged 'Left'.
        mkResAndPat' summary =
          case relatedUpdate summary of
            Just (update, _, _) ->
              Right $ Ident (updateName update) (snd $ updateType update)
            Nothing ->
              Left $ inPatternAs summary

-- | Build one 'LoopResultSummary' per loop result, or fail.
--
-- The third and fourth arguments are zipped positionally: each
-- (body result, pattern identifier) pair is matched with one
-- (merge parameter, paired sub-expression) pair.
--
-- The overall result is 'Nothing' as soon as any single pair cannot
-- be summarised.  A pair can be summarised only if all of these hold:
--
--   * some desired update is accepted by 'updateHasValue' for the
--     name of the pattern identifier (the first such update is used);
--
--   * the source of that update is not in the given set of names
--     used in the body;
--
--   * every dimension of the merge parameter's type is either a
--     constant or a variable that is not itself a merge parameter.
--
-- On success the returned action allocates one fresh name
-- ("lowered_array") per summary.
summariseLoop :: MonadFreshNames m =>
                 [DesiredUpdate (als, Type)]
              -> Names
              -> [(SubExp, Ident)]
              -> [(Param DeclType, SubExp)]
              -> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop updates usedInBody resmap merge =
  sequence <$> zipWithM summariseLoopResult resmap merge
  where
    summariseLoopResult (se, v) (fparam, mergeinit) = do
      -- XXX: conservative; a result without a matching update makes
      -- the whole summary fail.  But this entire pass is going away.
      update <- find (updateHasValue $ identName v) updates
      guard $ not $ updateSource update `nameIn` usedInBody
      guard $ hasLoopInvariantShape fparam
      Just $ do
        lowered_array <- newVName "lowered_array"
        return LoopResultSummary
          { resultSubExp = se
          , inPatternAs = v
          , mergeParam = (fparam, mergeinit)
          , relatedUpdate = Just (update, lowered_array, updateType update)
          }

    -- All dimensions of the parameter's type are loop-invariant.
    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names = map (paramName . fst) merge

    -- Constants are invariant; a variable is invariant unless it names
    -- one of the merge parameters.
    loopInvariant (Var v) = v `notElem` merge_param_names
    loopInvariant Constant{} = True

-- | Everything recorded about a single loop result, as constructed by
-- 'summariseLoop'.
data LoopResultSummary dec =
  LoopResultSummary
    { resultSubExp :: SubExp
      -- ^ The sub-expression taken from the result list that
      -- 'summariseLoop' was given.
    , inPatternAs :: Ident
      -- ^ The identifier paired with that sub-expression.
    , mergeParam :: (Param DeclType, SubExp)
      -- ^ The merge parameter together with the sub-expression it is
      -- paired with in the merge list.
    , relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
      -- ^ If an update is associated with this result: the update, a
      -- name for the lowered array, and a decoration for it
      -- ('summariseLoop' uses the update's own 'updateType').
    }
  deriving (Show)

-- | Extract the index substitutions implied by a list of summaries.
--
-- Summaries without a related update contribute nothing.  Every other
-- summary yields one entry keyed by the name of its merge parameter,
-- carrying the update's certificates, the name and decoration stored
-- alongside the update, and the update's indices.  The order of the
-- summaries is preserved.
indexSubstitutions :: [LoopResultSummary dec]
                   -> IndexSubstitutions dec
indexSubstitutions = mapMaybe getSubstitution
  where
    getSubstitution res = substitutionFor <$> relatedUpdate res
      where
        substitutionFor (update, nm, dec) =
          ( paramName $ fst $ mergeParam res
          , (updateCertificates update, nm, dec, updateIndices update)
          )

-- | Compute the new result of the loop body, along with any extra
-- statements needed to produce it.
--
-- Results of summaries without a related update are passed through
-- unchanged and come first.  The remaining results are zipped
-- positionally with the given substitutions and follow in that order:
--
--   * if the result is exactly the variable being substituted, it is
--     replaced by the substitution's array name and no statement is
--     emitted;
--
--   * otherwise a certified 'Update' statement writes the result into
--     the substitution's array at the substitution's indices, binding
--     a fresh identifier (suffix "_updated") that becomes the result.
manipulateResult :: (Bindable lore, MonadFreshNames m) =>
                    [LoopResultSummary (LetDec lore)]
                 -> IndexSubstitutions (LetDec lore)
                 -> m (Result, Stms lore)
manipulateResult summaries substs = do
  let (orig_ses, updated_ses) = partitionEithers $ map unchangedRes summaries
  (subst_ses, res_bnds) <- runWriterT $ zipWithM substRes updated_ses substs
  return (orig_ses ++ subst_ses, stmsFromList res_bnds)
  where
    -- 'Left' for results that stay as they are, 'Right' for results
    -- that belong to an update.
    unchangedRes summary =
      let res = resultSubExp summary
      in case relatedUpdate summary of
           Nothing -> Left res
           Just _ -> Right res

    substRes res_se (subst_v, (cs, nm, dec, is))
      | Var res_v <- res_se, res_v == subst_v =
          return $ Var nm
      | otherwise = do
          v' <- newIdent' (++"_updated") $ Ident nm dec_t
          tell [ certify cs $ mkLet [] [v'] $ BasicOp $
                   Update nm (fullSlice dec_t is) res_se
               ]
          return $ Var $ identName v'
      where
        dec_t = typeOf dec