{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.InPlaceLowering.LowerIntoStm
  ( lowerUpdateKernels,
    lowerUpdate,
    LowerUpdate,
    DesiredUpdate (..),
  )
where

import Control.Monad
import Control.Monad.Writer
import Data.Either
import Data.List (find, unzip4)
import Data.Maybe (isNothing, mapMaybe)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.Optimise.InPlaceLowering.SubstituteIndices

-- | An in-place update that the lowering pass would like to fold into
-- an earlier statement.  As consumed by the lowering functions in this
-- module, it describes binding 'updateName' to 'updateSource' with the
-- elements at 'updateIndices' replaced by 'updateValue'.
data DesiredUpdate dec = DesiredUpdate
  { -- | Name of result.
    updateName :: VName,
    -- | Type of result.
    updateType :: dec,
    -- | Certificates attached to the update; the lowering functions
    -- merge these into the certificates of the statements they emit.
    updateCertificates :: Certificates,
    -- | The array being updated in place.
    updateSource :: VName,
    -- | Indices at which the update writes.
    updateIndices :: Slice SubExp,
    -- | Name of the value written by the update.
    updateValue :: VName
  }
  deriving (Show)

-- | Map over the type decoration ('updateType') of an update, leaving
-- every other field untouched.
instance Functor DesiredUpdate where
  f `fmap` u = u {updateType = f $ updateType u}

-- | Is the given name the value ('updateValue') written by this update?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name = (name ==) . updateValue

-- | A strategy for lowering in-place updates into a statement.  Given
-- the current scope, a statement, and the updates to absorb, it yields
-- 'Nothing' when the statement cannot absorb them, or otherwise an
-- action producing the replacement statements.
type LowerUpdate lore m =
  Scope (Aliases lore)
    -> Stm (Aliases lore)
    -> [DesiredUpdate (LetDec (Aliases lore))]
    -> Maybe (m [Stm (Aliases lore)])

-- | Generic (lore-independent) lowering of desired updates into a
-- statement.  Two statement forms are handled:
--
--   * a 'DoLoop', which is delegated to 'lowerUpdateIntoLoop'; the
--     rewritten loop is wrapped by the pre- and post-statements that
--     function produces, and keeps the original statement's
--     certificates;
--
--   * a plain variable binding @let src = v@ with exactly one desired
--     update whose 'updateSource' is @src@, which becomes
--     @let bindee = v with [is'] = val@, where @is'@ is the update's
--     indices passed through 'fullSlice' for the result type.
--
-- Any other statement yields 'Nothing'.
lowerUpdate ::
  ( MonadFreshNames m,
    Bindable lore,
    LetDec lore ~ Type,
    CanBeAliased (Op lore)
  ) =>
  LowerUpdate lore m
lowerUpdate scope (Let pat aux (DoLoop ctx val form body)) updates = do
  canDo <- lowerUpdateIntoLoop scope updates pat ctx val form body
  Just $ do
    (prebnds, postbnds, ctxpat, valpat, ctx', val', body') <- canDo
    return $
      prebnds
        ++ [ certify (stmAuxCerts aux) $
               mkLet ctxpat valpat $ DoLoop ctx' val' form body'
           ]
        ++ postbnds
lowerUpdate
  _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src is val]
    | patternNames pat == [src] =
      let is' = fullSlice (typeOf bindee_dec) is
       in Just $
            return
              [ -- Keep the certificates of both the original statement
                -- and the update.
                certify (stmAuxCerts aux <> cs) $
                  mkLet [] [Ident bindee_nm $ typeOf bindee_dec] $
                    BasicOp $ Update v is' $ Var val
              ]
lowerUpdate _ _ _ =
  Nothing

-- | Lowering for the 'Kernels' representation.  A 'SegMap' whose
-- pattern binds the 'updateValue' of every desired update is handed to
-- 'lowerUpdatesIntoSegMap'.  The rewritten kernel is followed by the
-- statements that rebind the original result names, and it is
-- certified with the statement's own certificates plus those of all
-- updates.  Every other statement falls back to the generic
-- 'lowerUpdate'.
lowerUpdateKernels :: MonadFreshNames m => LowerUpdate Kernels m
lowerUpdateKernels
  scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody))))
  updates
    | all ((`elem` patternNames pat) . updateValue) updates = do
      mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
      Just $ do
        (pat', kbody', poststms) <- mk
        let cs = stmAuxCerts aux <> foldMap updateCertificates updates
        return $
          certify cs (Let pat' aux $ Op $ SegOp $ SegMap lvl space ts kbody') :
          stmsToList poststms
lowerUpdateKernels scope stm updates = lowerUpdate scope stm updates

-- | Try to lower desired updates into the body of a 'SegMap'.
--
-- Consider each pattern element of the kernel that is the
-- 'updateValue' of some desired update.  Its 'Returns' result is
-- turned into a 'WriteReturns' that writes directly into the update's
-- source array, and the pattern element is renamed to the update's
-- result name.  The original name is then rebound after the kernel by
-- indexing the updated array; these are the returned post-statements.
--
-- Returns the new pattern, the new kernel body, and the
-- post-statements.  The lowering is all-or-nothing.  It yields
-- 'Nothing' if any matching result is not a plain 'Returns', or if a
-- per-thread slice does not fully cover the underlying dimensions.
lowerUpdatesIntoSegMap ::
  MonadFreshNames m =>
  Scope (Aliases Kernels) ->
  Pattern (Aliases Kernels) ->
  [DesiredUpdate (LetDec (Aliases Kernels))] ->
  SegSpace ->
  KernelBody (Aliases Kernels) ->
  Maybe
    ( m
        ( Pattern (Aliases Kernels),
          KernelBody (Aliases Kernels),
          Stms (Aliases Kernels)
        )
    )
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  -- The updates are all-or-nothing.  Being more liberal would require
  -- changes to the in-place-lowering pass itself.
  mk <- zipWithM onRet (patternElements pat) (kernelBodyResult kbody)
  return $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    return
      ( Pattern [] pes,
        kbody
          { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms,
            kernelBodyResult = krets
          },
        mconcat poststms
      )
  where
    (gtids, _dims) = unzip $ unSegSpace kspace

    -- Rewrite a single kernel result, if some desired update writes it.
    onRet (PatElem v v_dec) ret
      | Just (DesiredUpdate bindee_nm bindee_dec _cs src slice _val) <-
          find (updateHasValue v) updates = do
        -- Only plain 'Returns' results can be turned into writes.
        Returns _ se <- Just ret

        -- The slice we're writing per thread must fully cover the
        -- underlying dimensions.
        guard $
          let (dims', slice') =
                unzip . drop (length gtids) . filter (isNothing . dimFix . snd) $
                  zip (arrayDims (typeOf bindee_dec)) slice
           in isFullSlice (Shape dims') slice'

        Just $ do
          -- Per-thread write indices: the update's slice applied to the
          -- thread's global indices, bound as kernel-body statements.
          (slice', bodystms) <-
            flip runBinderT scope $
              traverse (toSubExp "index") $
                fixSlice (map (fmap pe64) slice) $
                  map (pe64 . Var) gtids

          let res_dims = arrayDims $ snd bindee_dec
              ret' = WriteReturns res_dims src [(map DimFix slice', se)]

          return
            ( PatElem bindee_nm bindee_dec,
              bodystms,
              ret',
              -- Rebind the original result name by reading it back
              -- from the updated array.
              oneStm $
                mkLet [] [Ident v $ typeOf v_dec] $
                  BasicOp $ Index bindee_nm slice
            )
    -- Results not targeted by any update pass through unchanged.
    onRet pe ret =
      Just $ return (pe, mempty, ret, mempty)

-- | Attempt to lower a list of desired in-place updates into a loop.
--
-- Produces 'Nothing' when the transformation is not allowed (see the
-- safety conditions checked below).  Otherwise it produces an action that
-- yields, in order:
--
--   * statements binding the new merge-parameter initial values
--     (@prebnds@; presumably placed before the loop by the caller),
--   * statements reading the updated elements back out of the loop
--     result (@postbnds@; presumably placed after the loop),
--   * the context pattern and the new value pattern,
--   * the context merge parameters and the new value merge parameters,
--   * the rewritten loop body.
lowerUpdateIntoLoop ::
  ( Bindable lore,
    BinderOps lore,
    Aliased lore,
    LetDec lore ~ (als, Type),
    MonadFreshNames m
  ) =>
  Scope lore ->
  [DesiredUpdate (LetDec lore)] ->
  Pattern lore ->
  [(FParam lore, SubExp)] ->
  [(FParam lore, SubExp)] ->
  LoopForm lore ->
  Body lore ->
  Maybe
    ( m
        ( [Stm lore],
          [Stm lore],
          [Ident],
          [Ident],
          [(FParam lore, SubExp)],
          [(FParam lore, SubExp)],
          Body lore
        )
    )
lowerUpdateIntoLoop scope updates pat ctx val form body = do
  -- Algorithm:
  --
  --   0) Map each result of the loop body to a corresponding in-place
  --      update, if one exists.
  --
  --   1) Create new merge variables corresponding to the arrays being
  --      updated; extend the pattern and the @res@ list with these,
  --      and remove the parts of the result list that have a
  --      corresponding in-place update.
  --
  --      (The creation of the new merge variable identifiers is
  --      actually done at the same time as step (0)).
  --
  --   2) Create in-place updates at the end of the loop body.
  --
  --   3) Create index expressions that read back the values written
  --      in (2).  If the merge parameter corresponding to this value
  --      is unique, also @copy@ this value.
  --
  --   4) Update the result of the loop body to properly pass the new
  --      arrays and indexed elements to the next iteration of the
  --      loop.
  --
  -- We also check that the merge parameters we work with have
  -- loop-invariant shapes.

  -- Everything below pairs 'bodyResult' / 'bodyAliases' positionally
  -- with the *value* merge parameters and the *value* pattern, and the
  -- rewritten body only contains results derived from the value
  -- summaries, while the context merge parameters are passed through
  -- unchanged.  A loop body returns its context results before its value
  -- results, so with a non-empty context the results would be paired
  -- with the wrong parameters and the new body would lack its context
  -- results.  Refuse conservatively in that case.
  guard $ null ctx

  -- Safety condition (8): no value merge parameter may be aliased by
  -- the loop body result in its position.
  forM_ (zip val $ bodyAliases body) $ \((p, _), als) ->
    guard $ not $ paramName p `nameIn` als

  mk_in_place_map <- summariseLoop scope updates usedInBody resmap val

  Just $ do
    in_place_map <- mk_in_place_map
    (val', prebnds, postbnds) <- mkMerges in_place_map
    let (ctxpat, valpat) = mkResAndPat in_place_map
        idxsubsts = indexSubstitutions in_place_map
    (idxsubsts', newbnds) <- substituteIndices idxsubsts $ bodyStms body
    (body_res, res_bnds) <- manipulateResult in_place_map idxsubsts'
    let body' = mkBody (newbnds <> res_bnds) body_res
    return (prebnds, postbnds, ctxpat, valpat, ctx, val', body')
  where
    -- Aliases of every name free in the loop body or loop form.
    usedInBody =
      mconcat $
        map (`lookupAliases` scope) $
          namesToList $ freeIn body <> freeIn form

    -- Each body result paired with the pattern element it is bound to.
    resmap = zip (bodyResult body) $ patternValueIdents pat

    -- Build the new value merge parameters: the original ones first, then
    -- one new parameter per lowered update, plus the statements that go
    -- before and after the loop.
    mkMerges ::
      (MonadFreshNames m, Bindable lore) =>
      [LoopResultSummary (als, Type)] ->
      m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
    mkMerges summaries = do
      ((origmerge, extramerge), (prebnds, postbnds)) <-
        runWriterT $ partitionEithers <$> mapM mkMerge summaries
      return (origmerge ++ extramerge, prebnds, postbnds)

    -- For a result with a related update, bind @modified_source@ to the
    -- update's source array with the merge parameter's initial value
    -- written at the update indices.  It becomes the initial value of a
    -- new unique merge parameter.  Also emit a read of the element at
    -- the update indices of the loop result, bound to the update's value
    -- name.  Results without an update keep their merge parameter.
    mkMerge summary
      | Just (update, mergename, mergedec) <- relatedUpdate summary = do
        source <- newVName "modified_source"
        let source_t = snd $ updateType update
            elmident =
              Ident
                (updateValue update)
                (source_t `setArrayDims` sliceDims (updateIndices update))
        tell
          ( [ mkLet [] [Ident source source_t] $
                BasicOp $
                  Update
                    (updateSource update)
                    (fullSlice source_t $ updateIndices update)
                    $ snd $ mergeParam summary
            ],
            [ mkLet [] [elmident] $
                BasicOp $
                  Index
                    (updateName update)
                    (fullSlice source_t $ updateIndices update)
            ]
          )
        return $
          Right
            ( Param mergename (toDecl (typeOf mergedec) Unique),
              Var source
            )
      | otherwise = return $ Left $ mergeParam summary

    -- The loop's new patterns: the context pattern is unchanged; the
    -- value pattern lists untouched results first, then the updated
    -- arrays (matching the order produced by 'mkMerges').
    mkResAndPat summaries =
      let (origpat, extrapat) = partitionEithers $ map mkResAndPat' summaries
       in ( patternContextIdents pat,
            origpat ++ extrapat
          )

    mkResAndPat' summary
      | Just (update, _, _) <- relatedUpdate summary =
        Right (Ident (updateName update) (snd $ updateType update))
      | otherwise =
        Left (inPatternAs summary)

-- | Pair every value result of a loop with the desired update that writes
-- it.  The outer 'Maybe' is 'Nothing' as soon as any single result fails
-- one of the checks below.  The inner action only allocates fresh names
-- for the lowered arrays.
summariseLoop ::
  ( Aliased lore,
    MonadFreshNames m
  ) =>
  Scope lore ->
  [DesiredUpdate (als, Type)] ->
  Names ->
  [(SubExp, Ident)] ->
  [(Param DeclType, SubExp)] ->
  Maybe (m [LoopResultSummary (als, Type)])
summariseLoop scope updates usedInBody resmap merge =
  sequence <$> zipWithM summarise resmap merge
  where
    summarise (se, v) (fparam, mergeinit) = do
      -- XXX: conservative; a result without a matching desired update
      -- makes the whole lowering fail.  But this entire pass is going away.
      update <- find (updateHasValue $ identName v) updates
      -- Safety condition (7): the aliases of the update's source must not
      -- intersect anything in 'usedInBody'.
      guard $ not $ usedInBody `namesIntersect` lookupAliases (updateSource update) scope
      guard $ hasLoopInvariantShape fparam
      pure $ mkSummary se v (fparam, mergeinit) update

    -- Allocate the name of the lowered array and record the summary.
    mkSummary se v merge_p update = do
      lowered_array <- newVName "lowered_array"
      pure
        LoopResultSummary
          { resultSubExp = se,
            inPatternAs = v,
            mergeParam = merge_p,
            relatedUpdate = Just (update, lowered_array, updateType update)
          }

    -- No dimension of the parameter may be one of the merge parameters.
    hasLoopInvariantShape fparam =
      all loopInvariant $ arrayDims $ paramType fparam

    merge_param_names = paramName . fst <$> merge

    loopInvariant se = case se of
      Var v -> v `notElem` merge_param_names
      Constant {} -> True

-- | Everything known about one value result of a loop that is a candidate
-- for in-place lowering.
data LoopResultSummary dec = LoopResultSummary
  { -- | What the loop body returns in this result position.
    resultSubExp :: SubExp,
    -- | The pattern element the loop result is bound to.
    inPatternAs :: Ident,
    -- | The merge parameter for this position and its initial value.
    mergeParam :: (Param DeclType, SubExp),
    -- | The desired update writing this result (if any), together with
    -- the fresh name chosen for the lowered array and its decoration.
    relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
  }
  deriving (Show)

-- | For every summary that has a related update, map the name of its
-- merge parameter to the update's certificates, the lowered array name,
-- its decoration, and the update indices.  Summaries without an update
-- contribute nothing.
indexSubstitutions ::
  [LoopResultSummary dec] ->
  IndexSubstitutions dec
indexSubstitutions summaries =
  [ (paramName merge_p, (updateCertificates upd, lowered, dec, updateIndices upd))
    | summary <- summaries,
      Just (upd, lowered, dec) <- [relatedUpdate summary],
      let (merge_p, _) = mergeParam summary
  ]

-- | Compute the new loop body result.  Results without a related update
-- come first, unchanged.  They are followed by one result per substitution:
--
--   * if the original result is exactly the substituted name, the lowered
--     array is returned directly;
--   * otherwise the result value is written into the lowered array at the
--     substitution's indices (under its certificates), and the updated
--     array is returned.
--
-- Also returns the statements performing those writes.
manipulateResult ::
  (Bindable lore, MonadFreshNames m) =>
  [LoopResultSummary (LetDec lore)] ->
  IndexSubstitutions (LetDec lore) ->
  m (Result, Stms lore)
manipulateResult summaries substs = do
  (subst_ses, res_stms) <- runWriterT $ zipWithM substRes updated_ses substs
  pure (orig_ses ++ subst_ses, stmsFromList res_stms)
  where
    (orig_ses, updated_ses) = partitionEithers $ map classify summaries

    -- Left: keep as-is; Right: has an update and must be substituted.
    classify summary
      | isNothing (relatedUpdate summary) = Left $ resultSubExp summary
      | otherwise = Right $ resultSubExp summary

    substRes res_se (subst_v, (cs, nm, dec, is)) =
      case res_se of
        Var res_v
          | res_v == subst_v -> pure $ Var nm
        _ -> do
          v' <- newIdent' (<> "_updated") $ Ident nm $ typeOf dec
          tell
            [ certify cs $
                mkLet [] [v'] $
                  BasicOp $
                    Update nm (fullSlice (typeOf dec) is) res_se
            ]
          pure $ Var $ identName v'