{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | This module exports facilities for transforming array accesses in
-- a list of 'Stm's (intended to be the bindings in a body).  The
-- idea is that you can state that some variable @x@ is in fact an
-- array indexing @v[i0,i1,...]@.
module Futhark.Optimise.InPlaceLowering.SubstituteIndices
       (
         substituteIndices
       , IndexSubstitution
       , IndexSubstitutions
       ) where

import Control.Monad
import qualified Data.Map.Strict as M

import Futhark.IR.Prop.Aliases
import Futhark.IR
import Futhark.Construct
import Futhark.Util

-- | A single substitution, as a tuple of:
--
--   * the certificates under which the indexing is performed,
--
--   * the array being indexed,
--
--   * the declaration (type information) of that array, and
--
--   * the slice with which that array is indexed.
type IndexSubstitution dec = (Certificates, VName, dec, Slice SubExp)

-- | An association list from variable names to the array indexing
-- each of those variables stands for.
type IndexSubstitutions dec = [(VName, IndexSubstitution dec)]

-- | Build a type environment from a set of substitutions.
--
-- The source array of every substitution is bound to its declaration
-- as a 'LetName'.  The substituted variables themselves (the keys of
-- the association list) are not added.
typeEnvFromSubstitutions :: LetDec lore ~ dec =>
                            IndexSubstitutions dec -> Scope lore
typeEnvFromSubstitutions = M.fromList . map (fromSubstitution . snd)
  where fromSubstitution (_, name, t, _) =
          (name, LetName t)

-- | Perform the substitution.
--
-- Rewrites the given statements so that references to substituted
-- variables in expressions and body results are replaced by the
-- corresponding array indexing.  The rewriting runs in a fresh
-- 'BinderT' whose initial scope contains the source arrays of the
-- substitutions.
--
-- Returns the substitutions as they stand after all statements have
-- been processed, together with the rewritten statements.
substituteIndices :: (MonadFreshNames m, BinderOps lore, Bindable lore,
                      Aliased lore, LetDec lore ~ dec) =>
                     IndexSubstitutions dec -> Stms lore
                  -> m (IndexSubstitutions dec, Stms lore)
substituteIndices substs bnds =
  runBinderT (substituteIndicesInStms substs bnds) types
  where types = typeEnvFromSubstitutions substs

-- | Apply the substitutions to each statement in order.
--
-- The substitutions returned for one statement are used for the next.
-- The rewritten statements are emitted into the binder by
-- 'substituteIndicesInStm'.
substituteIndicesInStms :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                           IndexSubstitutions (LetDec (Lore m))
                        -> Stms (Lore m)
                        -> m (IndexSubstitutions (LetDec (Lore m)))
substituteIndicesInStms = foldM substituteIndicesInStm

-- | Rewrite a single statement.
--
-- The steps are:
--
--   * substitute into the expression (any helper statements this
--     creates are emitted first),
--
--   * substitute into the pattern,
--
--   * emit the rewritten statement with 'addStm'.
--
-- Returns the substitutions produced by 'substituteIndicesInPattern'.
-- Substitutions redirected inside 'substituteIndicesInExp' are not
-- propagated to later statements.
substituteIndicesInStm :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                          IndexSubstitutions (LetDec (Lore m))
                       -> Stm (Lore m)
                       -> m (IndexSubstitutions (LetDec (Lore m)))
substituteIndicesInStm substs (Let pat aux e) = do
  e' <- substituteIndicesInExp substs e
  (substs', pat') <- substituteIndicesInPattern substs pat
  addStm $ Let pat' aux e'
  return substs'

-- | Substitute into a pattern.
--
-- Returns the substitutions to use for subsequent statements together
-- with the new pattern.  Context elements are traversed first, then
-- value elements, and the substitutions are threaded through both.
--
-- NOTE(review): the per-element step @sub@ returns the substitutions
-- and the pattern element unchanged, so this function currently
-- yields its inputs as-is.  The traversal is presumably kept as a hook
-- for per-element rewriting — confirm before simplifying.
substituteIndicesInPattern :: (MonadBinder m, LetDec (Lore m) ~ dec) =>
                              IndexSubstitutions (LetDec (Lore m))
                           -> PatternT dec
                           -> m (IndexSubstitutions (LetDec (Lore m)), PatternT dec)
substituteIndicesInPattern substs pat = do
  (substs', context) <- mapAccumLM sub substs $ patternContextElements pat
  (substs'', values) <- mapAccumLM sub substs' $ patternValueElements pat
  return (substs'', Pattern context values)
  where sub substs' patElem = return (substs', patElem)

-- | Substitute into an expression.
--
-- First, every substituted variable that the expression consumes
-- (according to 'consumedInExp') is materialised:
--
--   * its indexed row is bound to a fresh variable whose name ends in
--     @_row@, under the substitution's certificates;
--
--   * that row is copied into a fresh variable whose name ends in
--     @_row_copy@;
--
--   * the substitution for the variable is redirected to the copy,
--     with an empty slice and no certificates.
--
-- NOTE(review): the copy presumably keeps the consumption from
-- touching the source array itself — confirm.
--
-- Then all subexpressions, variable names and nested bodies of the
-- expression are rewritten using the (possibly redirected)
-- substitutions.  The redirected substitutions are only used inside
-- this expression and are not returned.
substituteIndicesInExp :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m),
                           LetDec (Lore m) ~ dec) =>
                          IndexSubstitutions (LetDec (Lore m))
                       -> Exp (Lore m)
                       -> m (Exp (Lore m))
substituteIndicesInExp substs e = do
  substs' <- copyAnyConsumed e
  let substitute = identityMapper { mapOnSubExp = substituteIndicesInSubExp substs'
                                  , mapOnVName  = substituteIndicesInVar substs'
                                  , mapOnBody   = const $ substituteIndicesInBody substs'
                                  }

  mapExpM substitute e
  where copyAnyConsumed =
          let consumingSubst substs' v
                | Just (cs2, src2, src2dec, is2) <- lookup v substs = do
                    row <- certifying cs2 $
                           letExp (baseString v ++ "_row") $
                           BasicOp $ Index src2 $ fullSlice (typeOf src2dec) is2
                    row_copy <- letExp (baseString v ++ "_row_copy") $
                                BasicOp $ Copy row
                    -- The copy has the element type of the source array
                    -- with the dimensions produced by the slice.
                    return $ update v v (mempty,
                                         row_copy,
                                         src2dec `setType`
                                         (typeOf src2dec `setArrayDims`
                                          sliceDims is2),
                                         []) substs'
              consumingSubst substs' _ =
                return substs'
          in foldM consumingSubst substs . namesToList . consumedInExp

-- | Substitute into a subexpression.
--
-- Variables are handled by 'substituteIndicesInVar'.  Any other
-- subexpression is returned unchanged.
substituteIndicesInSubExp :: MonadBinder m =>
                             IndexSubstitutions (LetDec (Lore m))
                          -> SubExp
                          -> m SubExp
substituteIndicesInSubExp substs (Var v) =
  Var <$> substituteIndicesInVar substs v
substituteIndicesInSubExp _ se =
  return se

-- | Substitute a single variable name.
--
-- If the variable has a substitution, a new statement is bound under
-- the substitution's certificates and its name is returned:
--
--   * with an empty slice, the statement binds the source array
--     itself, to a name derived from the source array's name.
--     NOTE(review): presumably this exists so the certificates are
--     attached to a statement — confirm;
--
--   * otherwise, the statement indexes the source array with the
--     slice, completed by 'fullSlice', under the name @idx@.
--
-- Variables without a substitution are returned unchanged.
substituteIndicesInVar :: MonadBinder m =>
                          IndexSubstitutions (LetDec (Lore m))
                       -> VName
                       -> m VName
substituteIndicesInVar substs v =
  case lookup v substs of
    Just (cs2, src2, _, []) ->
      certifying cs2 $
      letExp (baseString src2) $ BasicOp $ SubExp $ Var src2
    Just (cs2, src2, src2_dec, is2) ->
      certifying cs2 $
      letExp "idx" $ BasicOp $ Index src2 $ fullSlice (typeOf src2_dec) is2
    Nothing ->
      return v

-- | Substitute into a body.
--
-- The body's statements are rewritten first, with those statements in
-- scope.  The result subexpressions are then rewritten using the
-- substitutions produced by the statements, with the rewritten
-- statements in scope.
--
-- The new body is built with 'mkBodyM' from the rewritten statements,
-- followed by any statements generated while rewriting the result.
-- The original body decoration is discarded.
substituteIndicesInBody :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                           IndexSubstitutions (LetDec (Lore m))
                        -> Body (Lore m)
                        -> m (Body (Lore m))
substituteIndicesInBody substs (Body _ stms res) = do
  (substs', stms') <- inScopeOf stms $
    collectStms $ substituteIndicesInStms substs stms
  (res', res_stms) <- inScopeOf stms' $
    collectStms $ mapM (substituteIndicesInSubExp substs') res
  mkBodyM (stms' <> res_stms) res'

-- | @update needle name subst substs@ replaces the first entry of
-- @substs@ whose key is @needle@ with the entry @(name, subst)@.
-- All other entries keep their order.
--
-- This function is partial: it calls 'error' if no entry has the key
-- @needle@.
update :: VName -> VName -> IndexSubstitution dec -> IndexSubstitutions dec
       -> IndexSubstitutions dec
update needle name subst ((othername, othersubst) : substs)
  | needle == othername = (name, subst)           : substs
  | otherwise           = (othername, othersubst) : update needle name subst substs
update needle _    _ [] = error $ "Cannot find substitution for " ++ pretty needle