{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module exports facilities for transforming array accesses in
-- a list of 'Stm's (intended to be the bindings in a body).  The
-- idea is that you can state that some variable @x@ is in fact an
-- array indexing @v[i0,i1,...]@.
module Futhark.Optimise.InPlaceLowering.SubstituteIndices
  ( substituteIndices,
    IndexSubstitution,
    IndexSubstitutions,
  )
where

import Control.Monad
import qualified Data.Map.Strict as M
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Transform.Substitute

-- | What a substituted variable stands for: the certificates to attach
-- when the access is materialised, the source array, the source
-- array's type, and the slice with which the source array is indexed.
type IndexSubstitution = (Certs, VName, Type, Slice SubExp)

-- | Association list from a substituted variable to the array indexing
-- it stands for.  Entries are found with 'lookup', so the first entry
-- for a name takes precedence; this module adds new entries at the
-- front.
type IndexSubstitutions = [(VName, IndexSubstitution)]

-- | The initial scope for rewriting: every source array mentioned by a
-- substitution, bound as a lambda parameter with its recorded type.
typeEnvFromSubstitutions :: LParamInfo rep ~ Type => IndexSubstitutions -> Scope rep
typeEnvFromSubstitutions substs =
  M.fromList [(src, LParamName src_t) | (_, (_, src, src_t, _)) <- substs]

-- | Perform the substitution.
--
-- Rewrites @stms@ so that every variable with an entry in @substs@ is
-- replaced by an indexing of its source array.  The source arrays are
-- put in scope with their recorded types while rewriting.  Returns the
-- substitutions in effect after the last statement (which may contain
-- entries added while rewriting) together with the rewritten
-- statements.
substituteIndices ::
  ( MonadFreshNames m,
    BuilderOps rep,
    Buildable rep,
    Aliased rep,
    LParamInfo rep ~ Type
  ) =>
  IndexSubstitutions ->
  Stms rep ->
  m (IndexSubstitutions, Stms rep)
substituteIndices substs stms =
  runBuilderT (substituteIndicesInStms substs stms) $
    typeEnvFromSubstitutions substs

-- | Rewrite each statement in order with 'substituteIndicesInStm',
-- threading the substitutions from one statement to the next.  Returns
-- the substitutions in effect after the last statement.
substituteIndicesInStms ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Stms (Rep m) ->
  m IndexSubstitutions
substituteIndicesInStms substs stms =
  foldM substituteIndicesInStm substs stms

-- | Rewrite one statement, adding the result to the builder state, and
-- return the substitutions in effect afterwards.
--
-- If the statement is a 'Rotate' or 'Rearrange' of a substituted array
-- and binds exactly one name, no statement for it is emitted.  Instead:
--
-- * the same operation is applied to the whole source array;
-- * the result is bound under the result name suffixed with @_subst@;
-- * a substitution for the result name is prepended.  It reuses the
--   original certificates and slice but points at the new array.
--
-- The source array may have more dimensions than the substituted
-- variable.  Those extra leading dimensions are left untouched: they
-- are rotated by zero, or kept in place by the permutation.
--
-- Every other statement has its expression rewritten by
-- 'substituteIndicesInExp' and is re-emitted with its pattern and aux
-- unchanged.
substituteIndicesInStm ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Stm (Rep m) ->
  m IndexSubstitutions
-- FIXME: we likely need to do something similar for all expressions
-- that produce aliases.  Ugh.  See issue #1460.  Or maybe we should
-- look at/copy all consumed arrays up front, instead of ad-hoc.
--
-- NOTE(review): on the Rotate/Rearrange paths the statement's aux
-- (certificates, attributes) is discarded; confirm that is intended.
substituteIndicesInStm substs (Let pat aux e)
  | BasicOp (Rotate rots v) <- e,
    Just subst <- lookup v substs,
    [v'] <- patNames pat =
    rebindSource v' subst $ \src_t ->
      let zero = intConst Int64 0
          padding = replicate (arrayRank src_t - length rots) zero
       in Rotate (padding ++ rots)
  | BasicOp (Rearrange perm v) <- e,
    Just subst <- lookup v substs,
    [v'] <- patNames pat =
    rebindSource v' subst $ \src_t ->
      let extra_dims = arrayRank src_t - length perm
       in Rearrange ([0 .. extra_dims - 1] ++ map (+ extra_dims) perm)
  | otherwise = do
    e' <- substituteIndicesInExp substs e
    addStm $ Let pat aux e'
    pure substs
  where
    -- Bind the operation built by @mk_op@ (given the source array's
    -- type) applied to the source array, then record that @v'@ now
    -- indexes the new array.
    rebindSource v' (cs, src, src_t, is) mk_op = do
      src' <- letExp (baseString v' <> "_subst") $ BasicOp $ mk_op src_t src
      src_t' <- lookupType src'
      pure $ (v', (cs, src', src_t', is)) : substs

-- | Rewrite a single expression according to the substitutions, adding
-- any statements the rewrite needs to the builder state.
--
-- For an 'Op', every substitution whose variable is free in the
-- operation is materialised up front as an explicit 'Index' of its
-- source array, under the substitution's certificates.  The operation
-- is then renamed with 'substituteNames' to refer to those fresh
-- variables.
--
-- For any other expression, each consumed array that has a
-- substitution is handled first:
--
-- * its row (source array indexed by the slice) is bound under a
--   @_row@ name and then 'Copy'-ed under a @_row_copy@ name;
-- * its substitution is replaced by one pointing at the copy, with an
--   empty slice and no certificates;
-- * so the consumption targets the copy rather than the source array.
--
-- After that, the expression's subexpressions, variable names and
-- nested bodies are rewritten through a 'Mapper'.
substituteIndicesInExp ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Exp (Rep m) ->
  m (Exp (Rep m))
substituteIndicesInExp substs (Op op) = do
  let op_free = freeIn op
      used_in_op = [subst | subst@(v, _) <- substs, v `nameIn` op_free]
  renamings <- forM used_in_op $ \(v, (cs, src, src_dec, Slice is)) -> do
    v' <-
      certifying cs . letExp (baseString src <> "_op_idx") $
        BasicOp $ Index src $ fullSlice (typeOf src_dec) is
    pure $ M.singleton v v'
  -- Left-biased merge: an earlier entry for the same name wins.
  pure . Op $ substituteNames (M.unions renamings) op
substituteIndicesInExp substs e = do
  substs' <- foldM copyIfConsumed substs . namesToList $ consumedInExp e
  let mapper =
        identityMapper
          { mapOnSubExp = substituteIndicesInSubExp substs',
            mapOnVName = substituteIndicesInVar substs',
            mapOnBody = \_ body -> substituteIndicesInBody substs' body
          }
  mapExpM mapper e
  where
    -- If the consumed name @v@ has a substitution, copy its row and
    -- make @v@ refer to the copy; otherwise leave @acc@ alone.
    copyIfConsumed acc v =
      case lookup v substs of
        Nothing -> pure acc
        Just (cs2, src2, src2dec, is2) -> do
          row <-
            certifying cs2 . letExp (baseString v ++ "_row") $
              BasicOp $ Index src2 $ fullSlice (typeOf src2dec) (unSlice is2)
          row_copy <- letExp (baseString v ++ "_row_copy") $ BasicOp $ Copy row
          let row_t = typeOf src2dec `setArrayDims` sliceDims is2
          pure $
            update v v (mempty, row_copy, src2dec `setType` row_t, Slice []) acc

-- | Rewrite a subexpression: variables go through
-- 'substituteIndicesInVar'; constants are returned unchanged.
substituteIndicesInSubExp ::
  MonadBuilder m =>
  IndexSubstitutions ->
  SubExp ->
  m SubExp
substituteIndicesInSubExp substs se =
  case se of
    Var v -> Var <$> substituteIndicesInVar substs v
    _ -> pure se

-- | Replace a variable by its substitution, if it has one.  All
-- emitted statements are under the substitution's certificates.
--
-- * Empty slice: a fresh name (based on the source array's name) is
--   bound to the source array itself.
-- * Non-empty slice: a fresh @_v_idx@ name is bound to the source array
--   indexed by the slice, completed with 'fullSlice'.
-- * No substitution: the variable is returned unchanged.
substituteIndicesInVar ::
  MonadBuilder m =>
  IndexSubstitutions ->
  VName ->
  m VName
substituteIndicesInVar substs v =
  case lookup v substs of
    Just (cs2, src2, _, Slice []) ->
      certifying cs2 . letExp (baseString src2) $
        BasicOp $ SubExp $ Var src2
    Just (cs2, src2, src2_dec, Slice is2) ->
      certifying cs2 . letExp (baseString src2 <> "_v_idx") $
        BasicOp $ Index src2 $ fullSlice (typeOf src2_dec) is2
    Nothing ->
      pure v

-- | Rewrite a nested body.
--
-- The body's statements are rewritten in their own scope.  Its results
-- are then rewritten using the substitutions in effect after those
-- statements.  Any statements produced while rewriting the results are
-- appended after the body's rewritten statements.
substituteIndicesInBody ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Body (Rep m) ->
  m (Body (Rep m))
substituteIndicesInBody substs (Body _ stms res) = do
  (substs', stms') <-
    inScopeOf stms . collectStms $
      substituteIndicesInStms substs stms
  (res', res_stms) <-
    inScopeOf stms' . collectStms $
      forM res $ \(SubExpRes cs se) ->
        SubExpRes cs <$> substituteIndicesInSubExp substs' se
  mkBodyM (stms' <> res_stms) res'

-- | @update needle name subst substs@ replaces the first entry of
-- @substs@ keyed by @needle@ with @(name, subst)@, keeping all other
-- entries and their order.  Calls 'error' if @needle@ has no entry; the
-- list is produced lazily, so this happens only once the traversal
-- reaches the end of the list.
update ::
  VName ->
  VName ->
  IndexSubstitution ->
  IndexSubstitutions ->
  IndexSubstitutions
update needle name subst = go
  where
    go [] = error $ "Cannot find substitution for " ++ pretty needle
    go (entry@(othername, _) : rest)
      | needle == othername = (name, subst) : rest
      | otherwise = entry : go rest