{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.SubstituteIndices
( substituteIndices,
IndexSubstitution,
IndexSubstitutions,
)
where
import Control.Monad
import qualified Data.Map.Strict as M
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Transform.Substitute
import Futhark.Util
-- | A single substitution: the certificates under which the replacement
-- indexing is emitted, the name of the source array that is indexed, the
-- @dec@ of that source array (its type is obtained with 'typeOf'), and the
-- slice applied to the source array.
type IndexSubstitution dec = (Certificates, VName, dec, Slice SubExp)

-- | An association list from the name of a variable to be replaced to the
-- 'IndexSubstitution' that replaces it.
type IndexSubstitutions dec = [(VName, IndexSubstitution dec)]
-- | Build a scope from a set of substitutions.  The key of each
-- association (the name being replaced) is ignored; the scope maps the
-- /source/ name of every substitution to a 'LetName' carrying the
-- substitution's @dec@.
typeEnvFromSubstitutions ::
  LetDec rep ~ dec =>
  IndexSubstitutions dec ->
  Scope rep
typeEnvFromSubstitutions = M.fromList . map (fromSubstitution . snd)
  where
    -- Only the source name and its dec matter for the scope.
    fromSubstitution (_, name, t, _) = (name, LetName t)
-- | Perform the given index substitutions on a sequence of statements.
--
-- The statements are processed with 'substituteIndicesInStms' inside a
-- 'BinderT' whose initial scope is derived from the substitutions
-- themselves (see 'typeEnvFromSubstitutions').  The result is the final
-- set of substitutions together with the statements produced by the
-- binder.
substituteIndices ::
  ( MonadFreshNames m,
    BinderOps rep,
    Bindable rep,
    Aliased rep,
    LetDec rep ~ dec
  ) =>
  IndexSubstitutions dec ->
  Stms rep ->
  m (IndexSubstitutions dec, Stms rep)
substituteIndices substs bnds =
  runBinderT (substituteIndicesInStms substs bnds) types
  where
    types = typeEnvFromSubstitutions substs
-- | Process the statements from left to right with
-- 'substituteIndicesInStm', threading the substitutions through and
-- returning the substitutions left after the last statement.
substituteIndicesInStms ::
  (MonadBinder m, Bindable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions (LetDec (Rep m)) ->
  Stms (Rep m) ->
  m (IndexSubstitutions (LetDec (Rep m)))
substituteIndicesInStms = foldM substituteIndicesInStm
-- | Apply the substitutions to a single statement.
--
-- The expression is rewritten first, then the pattern; the rewritten
-- statement (with its original auxiliary information) is added to the
-- binder, and the substitutions returned by the pattern pass are the
-- result.
substituteIndicesInStm ::
  (MonadBinder m, Bindable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions (LetDec (Rep m)) ->
  Stm (Rep m) ->
  m (IndexSubstitutions (LetDec (Rep m)))
substituteIndicesInStm substs (Let pat aux e) = do
  e' <- substituteIndicesInExp substs e
  (substs', pat') <- substituteIndicesInPattern substs pat
  addStm $ Let pat' aux e'
  return substs'
-- | Traverse the context elements and then the value elements of a
-- pattern, threading the substitutions through, and rebuild the pattern
-- from the traversed elements.
--
-- The per-element step currently returns both the substitutions and the
-- pattern element unchanged.
substituteIndicesInPattern ::
  (MonadBinder m, LetDec (Rep m) ~ dec) =>
  IndexSubstitutions (LetDec (Rep m)) ->
  PatternT dec ->
  m (IndexSubstitutions (LetDec (Rep m)), PatternT dec)
substituteIndicesInPattern substs pat = do
  (substs', context) <- mapAccumLM sub substs $ patternContextElements pat
  (substs'', values) <- mapAccumLM sub substs' $ patternValueElements pat
  return (substs'', Pattern context values)
  where
    -- Per-element step: passes accumulator and element through untouched.
    sub acc patElem = return (acc, patElem)
-- | Apply the substitutions to an expression.
--
-- For an 'Op': every substituted name that occurs free in the operation
-- is materialised as an @idx@ binding that indexes the source array
-- (under the substitution's certificates), and the operation is then
-- renamed to refer to those bindings.
--
-- For any other expression: every substituted name that the expression
-- consumes is first indexed out of its source (@<name>_row@) and copied
-- (@<name>_row_copy@), and its substitution is replaced by one that refers
-- to the copy with an empty slice.  The expression is then mapped over,
-- rewriting sub-expressions, variable names and bodies with the updated
-- substitutions.
substituteIndicesInExp ::
  ( MonadBinder m,
    Bindable (Rep m),
    Aliased (Rep m),
    LetDec (Rep m) ~ dec
  ) =>
  IndexSubstitutions (LetDec (Rep m)) ->
  Exp (Rep m) ->
  m (Exp (Rep m))
substituteIndicesInExp substs (Op op) = do
  -- Only substitutions whose name actually occurs in the op are relevant.
  let used_in_op = filter ((`nameIn` freeIn op) . fst) substs
  var_substs <- fmap mconcat $
    forM used_in_op $ \(v, (cs, src, src_dec, is)) -> do
      v' <-
        certifying cs $
          letExp "idx" $ BasicOp $ Index src $ fullSlice (typeOf src_dec) is
      pure $ M.singleton v v'
  pure $ Op $ substituteNames var_substs op
substituteIndicesInExp substs e = do
  substs' <- copyAnyConsumed e
  let substitute =
        identityMapper
          { mapOnSubExp = substituteIndicesInSubExp substs',
            mapOnVName = substituteIndicesInVar substs',
            mapOnBody = const $ substituteIndicesInBody substs'
          }
  mapExpM substitute e
  where
    -- Fold over the names consumed by the expression, starting from the
    -- incoming substitutions.
    copyAnyConsumed =
      foldM consumingSubst substs . namesToList . consumedInExp

    -- If the consumed name has a substitution (looked up in the incoming
    -- substitutions), bind the indexed row and a copy of it, and point the
    -- accumulated substitution for that name at the copy.
    consumingSubst acc v
      | Just (cs2, src2, src2dec, is2) <- lookup v substs = do
          row <-
            certifying cs2 $
              letExp (baseString v ++ "_row") $
                BasicOp $ Index src2 $ fullSlice (typeOf src2dec) is2
          row_copy <-
            letExp (baseString v ++ "_row_copy") $
              BasicOp $ Copy row
          return $
            update
              v
              v
              ( mempty,
                row_copy,
                src2dec
                  `setType` (typeOf src2dec `setArrayDims` sliceDims is2),
                []
              )
              acc
      | otherwise =
          return acc
-- | Apply the substitutions to a sub-expression.  Variables are handled
-- by 'substituteIndicesInVar'; every other sub-expression is returned
-- unchanged.
substituteIndicesInSubExp ::
  MonadBinder m =>
  IndexSubstitutions (LetDec (Rep m)) ->
  SubExp ->
  m SubExp
substituteIndicesInSubExp substs (Var v) =
  Var <$> substituteIndicesInVar substs v
substituteIndicesInSubExp _ se =
  return se
-- | Apply the substitutions to a variable name.
--
-- * If the name has a substitution with an empty slice, a binding named
--   after the source variable is emitted whose value is the source
--   variable itself, and the bound name is returned.
--
-- * If the name has a substitution with a non-empty slice, an @idx@
--   binding indexing the source array is emitted, and the bound name is
--   returned.
--
-- * Otherwise the name is returned unchanged.
--
-- In both substitution cases the binding is emitted under the
-- substitution's certificates.
substituteIndicesInVar ::
  MonadBinder m =>
  IndexSubstitutions (LetDec (Rep m)) ->
  VName ->
  m VName
substituteIndicesInVar substs v =
  case lookup v substs of
    Just (cs2, src2, _, []) ->
      certifying cs2 $
        letExp (baseString src2) $ BasicOp $ SubExp $ Var src2
    Just (cs2, src2, src2_dec, is2) ->
      certifying cs2 $
        letExp "idx" $
          BasicOp $ Index src2 $ fullSlice (typeOf src2_dec) is2
    Nothing ->
      return v
-- | Apply the substitutions to a body.
--
-- The body's statements are rewritten with 'substituteIndicesInStms' and
-- collected.  The result sub-expressions are then rewritten with the
-- substitutions that remain afterwards; any statements emitted while doing
-- so are appended after the rewritten body statements.  The original body
-- decoration is discarded and the body is rebuilt with 'mkBodyM'.
substituteIndicesInBody ::
  (MonadBinder m, Bindable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions (LetDec (Rep m)) ->
  Body (Rep m) ->
  m (Body (Rep m))
substituteIndicesInBody substs (Body _ stms res) = do
  (substs', stms') <-
    inScopeOf stms $
      collectStms $ substituteIndicesInStms substs stms
  (res', res_stms) <-
    inScopeOf stms' $
      collectStms $ mapM (substituteIndicesInSubExp substs') res
  mkBodyM (stms' <> res_stms) res'
-- | @update needle name subst substs@ replaces the first association in
-- @substs@ whose key equals @needle@ with @(name, subst)@, leaving all
-- other associations in place and in order.
--
-- Calls 'error' if no association has the key @needle@.
update ::
  VName ->
  VName ->
  IndexSubstitution dec ->
  IndexSubstitutions dec ->
  IndexSubstitutions dec
update needle name subst ((othername, othersubst) : substs)
  | needle == othername = (name, subst) : substs
  | otherwise = (othername, othersubst) : update needle name subst substs
update needle _ _ [] =
  error $ "Cannot find substitution for " ++ pretty needle