{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.SubstituteIndices
( substituteIndices,
IndexSubstitution,
IndexSubstitutions,
)
where
import Control.Monad
import Data.Map.Strict qualified as M
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Transform.Substitute
-- | The components of an array-indexing expression that a variable is
-- to be replaced by: the certificates guarding the indexing, the array
-- being indexed, the type of that array, and the slice used to index it.
type IndexSubstitution = (Certs, VName, Type, Slice SubExp)
-- | Association list from variable names to the 'IndexSubstitution'
-- each should be replaced with.  It is searched with 'lookup', so the
-- first entry for a given name takes precedence; new entries are
-- prepended.
type IndexSubstitutions = [(VName, IndexSubstitution)]
-- | Build a scope that binds the source array of every substitution to
-- the type recorded for it, so those arrays can be looked up while
-- statements are being rewritten.
typeEnvFromSubstitutions :: LParamInfo rep ~ Type => IndexSubstitutions -> Scope rep
typeEnvFromSubstitutions = M.fromList . map (fromSubstitution . snd)
  where
    -- Only the source array name and its type matter for the scope.
    fromSubstitution (_, name, t, _) = (name, LParamName t)
-- | Rewrite a sequence of statements so that every use of a variable
-- that has an entry in the substitution list is replaced by the
-- corresponding indexing expression.
--
-- Returns the substitution list as extended while processing the
-- statements (e.g. by rearrangements of substituted arrays), together
-- with the rewritten statements.
substituteIndices ::
  ( MonadFreshNames m,
    BuilderOps rep,
    Buildable rep,
    Aliased rep
  ) =>
  IndexSubstitutions ->
  Stms rep ->
  m (IndexSubstitutions, Stms rep)
substituteIndices substs stms =
  runBuilderT (substituteIndicesInStms substs stms) scope
  where
    -- The source arrays of the substitutions must be in scope so that
    -- their types can be looked up during rewriting.
    scope = typeEnvFromSubstitutions substs
-- | Rewrite statements in order, threading the substitution list
-- through so that substitutions introduced by an earlier statement
-- apply to the statements after it.  Returns the final substitution
-- list; the rewritten statements are emitted into the builder.
substituteIndicesInStms ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Stms (Rep m) ->
  m IndexSubstitutions
substituteIndicesInStms = foldM substituteIndicesInStm
-- | Rewrite a single statement and return the updated substitution list.
--
-- A single-result 'Rotate' or 'Rearrange' of a substituted variable is
-- not emitted as such.  Instead, the same operation is applied to the
-- substitution's source array, padded so the leading dimensions that
-- the source has beyond the variable's rank are left untouched.  A new
-- substitution is then recorded for the result.  Every other statement
-- has its expression rewritten by 'substituteIndicesInExp' and is added
-- to the builder unchanged otherwise.
substituteIndicesInStm ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Stm (Rep m) ->
  m IndexSubstitutions
substituteIndicesInStm substs (Let pat _ (BasicOp (Rotate rots v)))
  | Just (cs, src, src_t, is) <- lookup v substs,
    [v'] <- patNames pat = do
      -- Do not rotate the extra leading dimensions of the source.
      let extra_dims = arrayRank src_t - length rots
          rots' = replicate extra_dims (intConst Int64 0) ++ rots
      addDerivedSubstitution substs v' (cs, is) $ BasicOp $ Rotate rots' src
substituteIndicesInStm substs (Let pat _ (BasicOp (Rearrange perm v)))
  | Just (cs, src, src_t, is) <- lookup v substs,
    [v'] <- patNames pat = do
      -- Keep the extra leading dimensions of the source in place and
      -- shift the permutation past them.
      let extra_dims = arrayRank src_t - length perm
          perm' = [0 .. extra_dims - 1] ++ map (+ extra_dims) perm
      addDerivedSubstitution substs v' (cs, is) $ BasicOp $ Rearrange perm' src
substituteIndicesInStm substs (Let pat aux e) = do
  e' <- substituteIndicesInExp substs e
  addStm $ Let pat aux e'
  pure substs

-- | Bind @e@ (an expression derived from a substitution's source array)
-- to a fresh name based on @v'@.  Then prepend a substitution stating
-- that @v'@ is obtained by indexing that new array with the given
-- certificates and slice.
addDerivedSubstitution ::
  MonadBuilder m =>
  IndexSubstitutions ->
  VName ->
  (Certs, Slice SubExp) ->
  Exp (Rep m) ->
  m IndexSubstitutions
addDerivedSubstitution substs v' (cs, is) e = do
  src' <- letExp (baseString v' <> "_subst") e
  src_t' <- lookupType src'
  pure $ (v', (cs, src', src_t', is)) : substs
-- | Rewrite an expression under the given substitutions.
--
-- For an 'Op', each substituted variable that occurs free in the
-- operation gets an explicit 'Index' statement (under its
-- certificates), and the operation is renamed to refer to the result.
--
-- For any other expression, two steps happen:
--
-- * Every substituted variable the expression consumes is first
--   re-pointed at a fresh 'Copy' of the indexed row, with no
--   certificates and an empty slice.
-- * All variables, subexpressions and nested bodies of the expression
--   are then rewritten with the resulting substitutions.
substituteIndicesInExp ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Exp (Rep m) ->
  m (Exp (Rep m))
substituteIndicesInExp substs (Op op) = do
  let used_in_op = filter ((`nameIn` freeIn op) . fst) substs
  var_substs <- fmap mconcat $
    forM used_in_op $ \(v, (cs, src, src_dec, Slice is)) -> do
      v' <-
        certifying cs $
          letExp (baseString src <> "_op_idx") $
            BasicOp $
              Index src $
                fullSlice (typeOf src_dec) is
      pure $ M.singleton v v'
  pure $ Op $ substituteNames var_substs op
substituteIndicesInExp substs e = do
  substs' <- copyAnyConsumed e
  let substitute =
        identityMapper
          { mapOnSubExp = substituteIndicesInSubExp substs',
            mapOnVName = substituteIndicesInVar substs',
            mapOnBody = const $ substituteIndicesInBody substs'
          }
  mapExpM substitute e
  where
    -- Fold over the variables consumed by the expression, replacing the
    -- substitution of each consumed, substituted variable.
    copyAnyConsumed =
      foldM consumingSubst substs . namesToList . consumedInExp

    consumingSubst acc v
      | Just (cs2, src2, src2dec, is2) <- lookup v substs = do
          row <-
            certifying cs2 $
              letExp (baseString v <> "_row") $
                BasicOp $
                  Index src2 $
                    fullSlice (typeOf src2dec) (unSlice is2)
          row_copy <-
            letExp (baseString v <> "_row_copy") $ BasicOp $ Copy row
          -- The copy already has the sliced shape, so it is used whole:
          -- no certificates and an empty slice.
          let row_t = typeOf src2dec `setArrayDims` sliceDims is2
          pure $
            update v v (mempty, row_copy, src2dec `setType` row_t, Slice []) acc
      | otherwise = pure acc
-- | Rewrite a subexpression.  Variables are passed through
-- 'substituteIndicesInVar'; constants are returned unchanged.
substituteIndicesInSubExp ::
  MonadBuilder m =>
  IndexSubstitutions ->
  SubExp ->
  m SubExp
substituteIndicesInSubExp substs (Var v) = Var <$> substituteIndicesInVar substs v
substituteIndicesInSubExp _ se = pure se
-- | Rewrite a variable under the substitutions.
--
-- If the variable has a substitution, a new statement is emitted
-- (under the substitution's certificates) and its name is returned.
-- An empty slice rebinds the source array directly; otherwise the
-- source array is indexed with the full slice.  Variables without a
-- substitution are returned as-is.
substituteIndicesInVar ::
  MonadBuilder m =>
  IndexSubstitutions ->
  VName ->
  m VName
substituteIndicesInVar substs v =
  case lookup v substs of
    Just (cs, src, _, Slice []) ->
      certifying cs $
        letExp (baseString src) $
          BasicOp $
            SubExp $
              Var src
    Just (cs, src, src_dec, Slice is) ->
      certifying cs $
        letExp (baseString src <> "_v_idx") $
          BasicOp $
            Index src $
              fullSlice (typeOf src_dec) is
    Nothing ->
      pure v
-- | Rewrite a body.
--
-- The statements are rewritten in the scope of the original statements.
-- The body result is then rewritten under the substitutions produced
-- by those statements, in the scope of the rewritten statements.  Any
-- statements needed to compute the new result are appended after the
-- rewritten statements.
substituteIndicesInBody ::
  (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) =>
  IndexSubstitutions ->
  Body (Rep m) ->
  m (Body (Rep m))
substituteIndicesInBody substs (Body _ stms res) = do
  (substs', stms') <-
    inScopeOf stms $
      collectStms $
        substituteIndicesInStms substs stms
  (res', res_stms) <-
    inScopeOf stms' $
      collectStms $
        mapM (onSubExpRes substs') res
  mkBodyM (stms' <> res_stms) res'
  where
    -- Certificates on a result are kept; only the subexpression changes.
    onSubExpRes substs' (SubExpRes cs se) =
      SubExpRes cs <$> substituteIndicesInSubExp substs' se
-- | @update needle name subst substs@ replaces the first entry keyed by
-- @needle@ with @(name, subst)@, keeping its position in the list.
-- All other entries are left unchanged.
--
-- Calls 'error' if @needle@ has no entry; callers only use it for names
-- they have already found in the list.
update ::
  VName ->
  VName ->
  IndexSubstitution ->
  IndexSubstitutions ->
  IndexSubstitutions
update needle name subst ((othername, othersubst) : substs)
  | needle == othername = (name, subst) : substs
  | otherwise = (othername, othersubst) : update needle name subst substs
update needle _ _ [] =
  error $ "Cannot find substitution for " ++ prettyString needle