{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.InPlaceLowering.LowerIntoStm
  ( lowerUpdateGPU,
    lowerUpdate,
    LowerUpdate,
    DesiredUpdate (..),
  )
where

import Control.Monad
import Control.Monad.Writer
import Data.Either
import Data.List (find, unzip4)
import Data.Maybe (isNothing, mapMaybe)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.Optimise.InPlaceLowering.SubstituteIndices

-- | An in-place update that we would like to lower into the statement
-- that produces the updated value.
data DesiredUpdate dec = DesiredUpdate
  { -- | Name of result.
    updateName :: VName,
    -- | Type of result.
    updateType :: dec,
    -- | Certificates that the update depends on.
    updateCerts :: Certs,
    -- | The array being updated.
    updateSource :: VName,
    -- | The slice of 'updateSource' that is written.
    updateIndices :: Slice SubExp,
    -- | The name of the value written into the slice.
    updateValue :: VName
  }
  deriving (Show)

-- | Mapping over a 'DesiredUpdate' transforms only its 'updateType'.
instance Functor DesiredUpdate where
  f `fmap` u = u {updateType = f $ updateType u}

-- | Does this update write the given name (its 'updateValue')?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name = (name ==) . updateValue

-- | A function that tries to lower desired in-place updates into a
-- statement.  It receives the scope, the statement, and the updates we
-- would like to perform on the statement's results.  On success it
-- produces, in the monad @m@, the statements that replace the original
-- one; 'Nothing' means the updates cannot be lowered into this statement.
type LowerUpdate r m =
  Scope (Aliases r) ->
  Stm (Aliases r) ->
  [DesiredUpdate (LetDec (Aliases r))] ->
  Maybe (m [Stm (Aliases r)])

-- | Attempt to lower desired in-place updates into a statement.  Two
-- statement shapes are handled:
--
-- * A @DoLoop@, which is rewritten by 'lowerUpdateIntoLoop'.  The
--   statements it requests before and after the loop surround the
--   rewritten loop, which keeps the certificates of the original
--   statement.
--
-- * A variable binding @let src = v@ with exactly one desired update
--   whose source is @src@.  It becomes @let bindee = v with [is] = val@,
--   certified by both the statement's and the update's certificates.
--
-- Any other statement shape, or a loop that cannot be lowered, yields
-- 'Nothing'.
lowerUpdate ::
  ( MonadFreshNames m,
    Buildable rep,
    LetDec rep ~ Type,
    CanBeAliased (Op rep)
  ) =>
  LowerUpdate rep m
lowerUpdate scope (Let pat aux (DoLoop merge form body)) updates = do
  canDo <- lowerUpdateIntoLoop scope updates pat merge form body
  Just $ do
    (prestms, poststms, pat', merge', body') <- canDo
    pure $
      prestms
        ++ [ certify (stmAuxCerts aux) $
               mkLet pat' $
                 DoLoop merge' form body'
           ]
        ++ poststms
lowerUpdate
  _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src (Slice is) val]
    | patNames pat == [src] =
        -- NOTE(review): 'fullSlice' presumably pads the update indices to
        -- the rank of the result type - confirm against its definition.
        let is' = fullSlice (typeOf bindee_dec) is
         in Just . pure $
              [ certify (stmAuxCerts aux <> cs) $
                  mkLet [Ident bindee_nm $ typeOf bindee_dec] $
                    BasicOp $
                      Update Unsafe v is' $
                        Var val
              ]
lowerUpdate _ _ _ =
  Nothing

-- | Like 'lowerUpdate', but also able to lower updates into a @SegMap@.
--
-- The @SegMap@ path is taken when two conditions hold.  First, every
-- update's value must be bound by the statement's pattern.  Second, the
-- aliases of the kernel body's free variables must not overlap the
-- aliases of any update source.  In that case the result is whatever
-- 'lowerUpdatesIntoSegMap' produces, which may still be 'Nothing'.  In
-- every other case the work is delegated to 'lowerUpdate'.
lowerUpdateGPU :: MonadFreshNames m => LowerUpdate GPU m
lowerUpdateGPU
  scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody))))
  updates
    | all ((`elem` patNames pat) . updateValue) updates,
      not source_used_in_kbody = do
        mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
        Just $ do
          (pat', kbody', poststms) <- mk
          -- The rewritten kernel carries the certificates of the original
          -- statement as well as those of every update.
          let cs = stmAuxCerts aux <> foldMap updateCerts updates
          pure $
            certify cs (Let pat' aux $ Op $ SegOp $ SegMap lvl space ts kbody')
              : stmsToList poststms
    where
      -- This check is a bit more conservative than ideal.  In a perfect
      -- world, we would allow indexing a[i,j] if the update is also
      -- to exactly a[i,j], as that will not create cross-iteration
      -- dependencies.  (Although the type checker wouldn't be able to
      -- permit this anyway.)
      source_used_in_kbody =
        foldMap (`lookupAliases` scope) (namesToList (freeIn kbody))
          `namesIntersect` foldMap ((`lookupAliases` scope) . updateSource) updates
lowerUpdateGPU scope stm updates = lowerUpdate scope stm updates

-- | Lower desired updates into the results of a @SegMap@.
--
-- Consider a kernel result whose pattern element is the value of a
-- desired update.  Its @Returns@ result is turned into a @WriteReturns@
-- that writes directly into the update's source array, and its pattern
-- element is renamed to the update's result name.  The original value is
-- recovered afterwards by indexing that result and copying it.  Those
-- recovery statements are the third component of the result and go after
-- the kernel.
--
-- This is all-or-nothing: it yields 'Nothing' if any matching result is
-- not a plain @Returns@, or if its slice does not fully cover the
-- dimensions not indexed by the segment space.
lowerUpdatesIntoSegMap ::
  MonadFreshNames m =>
  Scope (Aliases GPU) ->
  Pat (LetDec (Aliases GPU)) ->
  [DesiredUpdate (LetDec (Aliases GPU))] ->
  SegSpace ->
  KernelBody (Aliases GPU) ->
  Maybe
    ( m
        ( Pat (LetDec (Aliases GPU)),
          KernelBody (Aliases GPU),
          Stms (Aliases GPU)
        )
    )
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  -- The updates are all-or-nothing.  Being more liberal would require
  -- changes to the in-place-lowering pass itself.
  mk <- zipWithM onRet (patElems pat) (kernelBodyResult kbody)
  pure $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    pure
      ( Pat pes,
        kbody
          { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms,
            kernelBodyResult = krets
          },
        mconcat poststms
      )
  where
    (gtids, _dims) = unzip $ unSegSpace kspace

    -- Rewrite one pattern element and its kernel result.  The result
    -- tuple holds the new pattern element, statements to append to the
    -- kernel body, the new kernel result, and statements to place after
    -- the kernel.
    onRet (PatElem v v_dec) ret
      | Just (DesiredUpdate bindee_nm bindee_dec _cs src slice _val) <-
          find ((== v) . updateValue) updates = do
          Returns _ cs se <- Just ret

          -- The slice we're writing per thread must fully cover the
          -- underlying dimensions.
          guard $
            let (dims', slice') =
                  unzip . drop (length gtids) . filter (isNothing . dimFix . snd) $
                    zip (arrayDims (typeOf bindee_dec)) (unSlice slice)
             in isFullSlice (Shape dims') (Slice slice')

          Just $ do
            -- Compute, inside the kernel body, the index each thread
            -- writes to, from the segment-space indices.
            (slice', bodystms) <-
              flip runBuilderT scope $
                traverse (toSubExp "index") $
                  fixSlice (fmap pe64 slice) $
                    map (pe64 . Var) gtids

            let res_dims = take (length slice') $ arrayDims $ snd bindee_dec
                ret' = WriteReturns cs (Shape res_dims) src [(Slice $ map DimFix slice', se)]

            v_aliased <- newName v

            pure
              ( PatElem bindee_nm bindee_dec,
                bodystms,
                ret',
                stmsFromList
                  [ mkLet [Ident v_aliased $ typeOf v_dec] $ BasicOp $ Index bindee_nm slice,
                    mkLet [Ident v $ typeOf v_dec] $ BasicOp $ Copy v_aliased
                  ]
              )
    onRet pe ret =
      Just $ pure (pe, mempty, ret, mempty)

-- | Lower desired updates into a loop.
--
-- On success the result holds, in order:
--
-- * the statements to place before the loop;
-- * the statements to place after the loop;
-- * the new loop pattern;
-- * the new merge parameters;
-- * the new loop body.
--
-- It yields 'Nothing' if safety condition (8) fails, or if 'summariseLoop'
-- rejects the updates.
lowerUpdateIntoLoop ::
  ( Buildable rep,
    BuilderOps rep,
    Aliased rep,
    LetDec rep ~ (als, Type),
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (LetDec rep)] ->
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  LoopForm rep ->
  Body rep ->
  Maybe
    ( m
        ( [Stm rep],
          [Stm rep],
          [Ident],
          [(FParam rep, SubExp)],
          Body rep
        )
    )
lowerUpdateIntoLoop scope updates pat val form body = do
  -- Algorithm:
  --
  --   0) Map each result of the loop body to a corresponding in-place
  --      update, if one exists.
  --
  --   1) Create new merge variables corresponding to the arrays being
  --      updated; extend the pattern and the @res@ list with these,
  --      and remove the parts of the result list that have a
  --      corresponding in-place update.
  --
  --      (The creation of the new merge variable identifiers is
  --      actually done at the same time as step (0)).
  --
  --   2) Create in-place updates at the end of the loop body.
  --
  --   3) Create index expressions that read back the values written
  --      in (2).  If the merge parameter corresponding to this value
  --      is unique, also @copy@ this value.
  --
  --      NOTE(review): 'mkMerge' below emits the @copy@
  --      unconditionally - confirm which is intended.
  --
  --   4) Update the result of the loop body to properly pass the new
  --      arrays and indexed elements to the next iteration of the
  --      loop.
  --
  -- We also check that the merge parameters we work with have
  -- loop-invariant shapes.

  -- Safety condition (8): no merge parameter may appear in the aliases
  -- of its corresponding body result.
  forM_ (zip val $ bodyAliases body) $ \((p, _), als) ->
    guard $ paramName p `notNameIn` als

  mk_in_place_map <- summariseLoop scope updates usedInBody resmap val

  Just $ do
    in_place_map <- mk_in_place_map
    (val', prestms, poststms) <- mkMerges in_place_map
    let valpat = mkResAndPat in_place_map
        idxsubsts = indexSubstitutions in_place_map
    (idxsubsts', newstms) <- substituteIndices idxsubsts $ bodyStms body
    (body_res, res_stms) <- manipulateResult in_place_map idxsubsts'
    let body' = mkBody (newstms <> res_stms) body_res
    pure (prestms, poststms, valpat, val', body')
  where
    -- Everything aliased by a variable free in the body or loop form.
    usedInBody =
      foldMap (`lookupAliases` scope) $ namesToList $ freeIn body <> freeIn form

    -- Pair each body result with the pattern identifier it is bound to.
    resmap = zip (bodyResult body) $ patIdents pat

    -- Build the new merge parameters.  Parameters without a related
    -- update come first, followed by the new ones.  Also collect the
    -- statements to place before and after the loop.
    mkMerges ::
      (MonadFreshNames m, Buildable rep) =>
      [LoopResultSummary (als, Type)] ->
      m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
    mkMerges summaries = do
      ((origmerge, extramerge), (prestms, poststms)) <-
        runWriterT $ partitionEithers <$> mapM mkMerge summaries
      pure (origmerge ++ extramerge, prestms, poststms)

    -- A summary without a related update keeps its merge parameter
    -- ('Left').  One with a related update gets a fresh unique merge
    -- parameter, initialised by an array modified before the loop
    -- ('Right').
    mkMerge summary
      | Just (update, mergename, mergedec) <- relatedUpdate summary = do
          source <- newVName "modified_source"
          precopy <- newVName $ baseString (updateValue update) <> "_precopy"
          let source_t = snd $ updateType update
              elm_t = source_t `setArrayDims` sliceDims (updateIndices update)
              upd_slice = fullSlice source_t $ unSlice $ updateIndices update
              merge_init = snd $ mergeParam summary
          tell
            ( -- Before the loop: write the merge parameter's initial
              -- value into the update source.
              [ mkLet [Ident source source_t] . BasicOp $
                  Update Unsafe (updateSource update) upd_slice merge_init
              ],
              -- After the loop: read the slice back out of the loop
              -- result (bound to 'updateName') and copy it into the name
              -- the update expects.
              [ mkLet [Ident precopy elm_t] . BasicOp $
                  Index (updateName update) upd_slice,
                mkLet [Ident (updateValue update) elm_t] $ BasicOp $ Copy precopy
              ]
            )
          pure $
            Right
              ( Param mempty mergename (toDecl (typeOf mergedec) Unique),
                Var source
              )
      | otherwise = pure $ Left $ mergeParam summary

    -- The new loop pattern: the original identifiers come first, followed
    -- by the results of the lowered updates.
    mkResAndPat summaries =
      let (origpat, extrapat) = partitionEithers $ map mkResAndPat' summaries
       in origpat ++ extrapat

    mkResAndPat' summary
      | Just (update, _, _) <- relatedUpdate summary =
          Right (Ident (updateName update) (snd $ updateType update))
      | otherwise =
          Left (inPatAs summary)

-- | Pair every loop result with the desired update whose value it produces.
--
-- Returns 'Nothing' as soon as any single loop result is not eligible for
-- lowering: it has no matching update, it fails safety condition (7), or its
-- loop parameter has a shape that mentions a loop parameter. Otherwise it
-- returns an action that allocates a fresh @lowered_array@ name for each
-- result.
summariseLoop ::
  ( Aliased rep,
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (als, Type)] ->
  Names ->
  [(SubExpRes, Ident)] ->
  [(Param DeclType, SubExp)] ->
  Maybe (m [LoopResultSummary (als, Type)])
summariseLoop scope updates usedInBody resmap merge =
  fmap sequence $ zipWithM summariseLoopResult resmap merge
  where
    summariseLoopResult (se, v) (fparam, mergeinit) = do
      -- XXX: conservative; but this entire pass is going away.  A result
      -- that no update refers to makes the whole loop ineligible.
      update <- find (updateHasValue (identName v)) updates
      -- Safety condition (7): the body must not use anything in the aliases
      -- of the update source.
      guard . not $
        usedInBody `namesIntersect` lookupAliases (updateSource update) scope
      -- No dimension of the parameter's type may be a loop parameter.
      guard $ hasLoopInvariantShape fparam
      Just $ do
        lowered_array <- newVName "lowered_array"
        pure
          LoopResultSummary
            { resultSubExp = se,
              inPatAs = v,
              mergeParam = (fparam, mergeinit),
              relatedUpdate = Just (update, lowered_array, updateType update)
            }

    hasLoopInvariantShape :: Param DeclType -> Bool
    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names :: [VName]
    merge_param_names = map (paramName . fst) merge

    -- Constants are trivially invariant; a variable is invariant unless it
    -- is one of the loop's own parameters.
    loopInvariant :: SubExp -> Bool
    loopInvariant Constant {} = True
    loopInvariant (Var name) = name `notElem` merge_param_names

-- | Everything the lowering needs to know about one result of the loop.
data LoopResultSummary dec = LoopResultSummary
  { -- | The body result that produces this loop result.
    resultSubExp :: SubExpRes,
    -- | The identifier the loop result is bound to outside the loop.
    inPatAs :: Ident,
    -- | The corresponding loop parameter together with its initial value.
    mergeParam :: (Param DeclType, SubExp),
    -- | If this result is lowered: the desired update, the fresh name of
    -- the lowered array, and that update's decoration.
    relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
  }
  deriving Show

-- | Compute the index substitutions induced by lowered loop results.
--
-- Every summary that has a related update contributes one substitution. The
-- substitution maps the name of its loop parameter to four things: the
-- update's certificates, the lowered array name, the type of the update's
-- decoration, and the update's indices. Summaries without an update
-- contribute nothing.
indexSubstitutions :: Typed dec => [LoopResultSummary dec] -> IndexSubstitutions
indexSubstitutions = mapMaybe substitutionFor
  where
    substitutionFor summary =
      case relatedUpdate summary of
        Nothing -> Nothing
        Just (DesiredUpdate _ _ cs _ is _, lowered, dec) ->
          let param_name = paramName $ fst $ mergeParam summary
           in Just (param_name, (cs, lowered, typeOf dec, is))

-- | Rewrite the results of a loop body after lowering.
--
-- Results without a related update are passed through unchanged and come
-- first. Results with an update are rewritten using the corresponding index
-- substitution and follow them. The rewrite either refers directly to the
-- lowered array, or performs an unsafe in-place 'Update' of the lowered
-- array. The 'Update' statements are returned alongside the new result.
--
-- NOTE(review): the updated results are paired with @substs@ by position,
-- so this assumes both are in the same order (e.g. 'indexSubstitutions' of
-- the same summaries). 'zipWithM' silently truncates on a length mismatch.
manipulateResult ::
  (Buildable rep, MonadFreshNames m) =>
  [LoopResultSummary (LetDec rep)] ->
  IndexSubstitutions ->
  m (Result, Stms rep)
manipulateResult summaries substs = do
  let (orig_ses, updated_ses) = partitionEithers (map classify summaries)
  (subst_ses, res_stms) <- runWriterT (zipWithM substRes updated_ses substs)
  pure (orig_ses ++ subst_ses, stmsFromList res_stms)
  where
    -- 'Left' for results left alone, 'Right' for results to be rewritten.
    classify summary
      | isNothing (relatedUpdate summary) = Left $ resultSubExp summary
      | otherwise = Right $ resultSubExp summary

    substRes (SubExpRes res_cs res_se) (subst_v, (cs, nm, dec, Slice is))
      | Var res_v <- res_se,
        res_v == subst_v =
          -- The result is exactly the substituted variable, so it can be
          -- replaced by the lowered array directly.
          pure $ SubExpRes res_cs $ Var nm
      | otherwise = do
          let arr_t = typeOf dec
          v' <- newIdent' (++ "_updated") $ Ident nm arr_t
          tell
            [ certify (res_cs <> cs) $
                mkLet [v'] $
                  BasicOp $
                    Update Unsafe nm (fullSlice arr_t is) res_se
            ]
          pure $ varRes $ identName v'