{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
( lowerUpdateGPU,
lowerUpdate,
LowerUpdate,
DesiredUpdate (..),
)
where
import Control.Monad
import Control.Monad.Writer
import Data.Either
import Data.List (find, unzip4)
import Data.Maybe (isNothing, mapMaybe)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.Optimise.InPlaceLowering.SubstituteIndices
-- | A request to perform an in-place update: bind 'updateName' to the
-- array 'updateSource' with 'updateValue' written at 'updateIndices',
-- under the certificates 'updateCerts'.
data DesiredUpdate dec = DesiredUpdate
  { -- | Name bound to the result of the update.
    updateName :: VName,
    -- | Decoration of the result; its 'typeOf' is the result type.
    updateType :: dec,
    -- | Certificates attached to the update.
    updateCerts :: Certs,
    -- | The array being updated.
    updateSource :: VName,
    -- | Where in 'updateSource' the value is written.
    updateIndices :: Slice SubExp,
    -- | The value written into the array.
    updateValue :: VName
  }
  deriving (Show)
-- | Maps over the decoration ('updateType'); every other field is kept.
instance Functor DesiredUpdate where
  fmap f u = u {updateType = f (updateType u)}
-- | Is the given name the value that this update writes into its array?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name = (name ==) . updateValue
-- | An attempt to lower desired updates into a statement.  Given the
-- scope, the statement and the updates, it yields 'Nothing' when the
-- statement cannot absorb the updates, and otherwise an action that
-- produces the replacement statements.
type LowerUpdate rep m =
  Scope (Aliases rep) ->
  Stm (Aliases rep) ->
  [DesiredUpdate (LetDec (Aliases rep))] ->
  Maybe (m [Stm (Aliases rep)])
-- | Lower desired updates into a statement of an arbitrary
-- representation.  Two statement forms are handled:
--
-- * A @DoLoop@, via 'lowerUpdateIntoLoop'.  The rewritten loop keeps
--   the original statement's certificates and is surrounded by the
--   pre- and post-statements that function produces.
--
-- * A variable copy @let src = v@ together with exactly one desired
--   update whose source is @src@.  This becomes an unsafe @Update@ of
--   @v@ that carries both the statement's and the update's
--   certificates.
--
-- Any other statement yields 'Nothing'.
lowerUpdate ::
  ( MonadFreshNames m,
    Buildable rep,
    LetDec rep ~ Type,
    AliasableRep rep
  ) =>
  LowerUpdate rep m
lowerUpdate scope (Let pat aux (DoLoop merge form body)) updates = do
  canDo <- lowerUpdateIntoLoop scope updates pat merge form body
  Just $ do
    (prestms, poststms, pat', merge', body') <- canDo
    let loop_stm =
          certify (stmAuxCerts aux) $
            mkLet pat' $
              DoLoop merge' form body'
    pure $ prestms ++ [loop_stm] ++ poststms
lowerUpdate
  _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src (Slice is) val]
    | patNames pat == [src] =
        -- Extend the update indices to a full slice of the result type.
        let is' = fullSlice (typeOf bindee_dec) is
         in Just . pure $
              [ certify (stmAuxCerts aux <> cs) $
                  mkLet [Ident bindee_nm $ typeOf bindee_dec] $
                    BasicOp $
                      Update Unsafe v is' $
                        Var val
              ]
lowerUpdate _ _ _ =
  Nothing
-- | Like 'lowerUpdate', but also able to lower updates into a 'SegMap'.
--
-- The 'SegMap' case applies only when every update's value is bound
-- by the statement's pattern, and no variable free in the kernel body
-- shares aliases with any update source.  If 'lowerUpdatesIntoSegMap'
-- then fails, the result is 'Nothing'.  Every other statement (or a
-- 'SegMap' failing the checks above) is delegated to 'lowerUpdate'.
lowerUpdateGPU :: MonadFreshNames m => LowerUpdate GPU m
lowerUpdateGPU
  scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody))))
  updates
    | all ((`elem` patNames pat) . updateValue) updates,
      not source_used_in_kbody = do
        mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
        Just $ do
          (pat', kbody', poststms) <- mk
          -- The lowered statement carries its own certificates plus
          -- those of every update folded into it.
          let cs = stmAuxCerts aux <> foldMap updateCerts updates
          pure $
            certify cs (Let pat' aux $ Op $ SegOp $ SegMap lvl space ts kbody')
              : stmsToList poststms
    where
      -- Conservative check: do the aliases of anything free in the
      -- kernel body intersect the aliases of any update source?
      source_used_in_kbody =
        mconcat (map (`lookupAliases` scope) (namesToList (freeIn kbody)))
          `namesIntersect` mconcat (map ((`lookupAliases` scope) . updateSource) updates)
lowerUpdateGPU scope stm updates = lowerUpdate scope stm updates
-- | Turn 'SegMap' results that have a desired update into
-- 'WriteReturns' straight into the update's source array.
--
-- Returns the new pattern, the new kernel body, and statements to run
-- after the 'SegMap'.  In the new kernel body, any index computations
-- are appended to the existing statements.  The post-statements
-- rebind each original result name by indexing the updated array and
-- copying.
--
-- The result is 'Nothing' when an updated result is not a 'Returns',
-- or when the per-thread slice does not fully cover the remaining
-- dimensions.  Either failure rejects all the updates together.
lowerUpdatesIntoSegMap ::
  MonadFreshNames m =>
  Scope (Aliases GPU) ->
  Pat (LetDec (Aliases GPU)) ->
  [DesiredUpdate (LetDec (Aliases GPU))] ->
  SegSpace ->
  KernelBody (Aliases GPU) ->
  Maybe
    ( m
        ( Pat (LetDec (Aliases GPU)),
          KernelBody (Aliases GPU),
          Stms (Aliases GPU)
        )
    )
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  mk <- zipWithM onRet (patElems pat) (kernelBodyResult kbody)
  pure $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    let kbody' =
          kbody
            { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms,
              kernelBodyResult = krets
            }
    pure (Pat pes, kbody', mconcat poststms)
  where
    -- The thread index variables of the SegMap.
    gtids = map fst $ unSegSpace kspace

    onRet (PatElem v v_dec) ret
      | Just (DesiredUpdate bindee_nm bindee_dec _ src slice _) <-
          find ((== v) . updateValue) updates = do
          -- Only plain returns can be turned into WriteReturns; any
          -- other kernel result makes this pattern match fail.
          Returns _ cs se <- Just ret

          -- Drop the fixed (DimFix) dimensions and then the first
          -- @length gtids@ remaining ones.  Whatever is left of the
          -- slice must be a full slice of the corresponding dimensions.
          guard $
            let (inner_dims, inner_slice) =
                  unzip . drop (length gtids) . filter (isNothing . dimFix . snd) $
                    zip (arrayDims (typeOf bindee_dec)) (unSlice slice)
             in isFullSlice (Shape inner_dims) (Slice inner_slice)

          Just $ do
            -- Compute, inside the kernel, the indices each thread writes at.
            (slice', bodystms) <-
              flip runBuilderT scope $
                traverse (toSubExp "index") $
                  fixSlice (fmap pe64 slice) $
                    map (pe64 . Var) gtids

            let res_dims = take (length slice') $ arrayDims $ snd bindee_dec
                ret' = WriteReturns cs (Shape res_dims) src [(Slice $ map DimFix slice', se)]

            v_aliased <- newName v

            pure
              ( PatElem bindee_nm bindee_dec,
                bodystms,
                ret',
                stmsFromList
                  [ mkLet [Ident v_aliased $ typeOf v_dec] $ BasicOp $ Index bindee_nm slice,
                    mkLet [Ident v $ typeOf v_dec] $ BasicOp $ Copy v_aliased
                  ]
              )
    onRet pe ret =
      Just $ pure (pe, mempty, ret, mempty)
-- | Try to lower desired updates into a @DoLoop@.
--
-- Each loop result that carries a desired update (see 'summariseLoop')
-- gets a new unique merge parameter.  That parameter is initialised
-- with the update's source array, into which the original merge
-- initial value has already been written.  The body is rewritten via
-- 'substituteIndices' and 'manipulateResult'.
--
-- The result tuple contains, in order:
--
-- * statements to place before the loop,
-- * statements to place after it,
-- * the new loop pattern,
-- * the new merge parameters,
-- * the new loop body.
--
-- The result is 'Nothing' when a merge parameter is aliased by its
-- corresponding body result, or when 'summariseLoop' rejects the loop.
lowerUpdateIntoLoop ::
  ( Buildable rep,
    BuilderOps rep,
    Aliased rep,
    LetDec rep ~ (als, Type),
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (LetDec rep)] ->
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  LoopForm rep ->
  Body rep ->
  Maybe
    ( m
        ( [Stm rep],
          [Stm rep],
          [Ident],
          [(FParam rep, SubExp)],
          Body rep
        )
    )
lowerUpdateIntoLoop scope updates pat val form body = do
  -- No merge parameter may be aliased by its corresponding body result.
  forM_ (zip val $ bodyAliases body) $ \((p, _), als) ->
    guard $ paramName p `notNameIn` als

  mk_in_place_map <- summariseLoop scope updates usedInBody resmap val

  Just $ do
    in_place_map <- mk_in_place_map
    (val', prestms, poststms) <- mkMerges in_place_map
    let valpat = mkResAndPat in_place_map
        idxsubsts = indexSubstitutions in_place_map
    (idxsubsts', newstms) <- substituteIndices idxsubsts $ bodyStms body
    (body_res, res_stms) <- manipulateResult in_place_map idxsubsts'
    let body' = mkBody (newstms <> res_stms) body_res
    pure (prestms, poststms, valpat, val', body')
  where
    -- Aliases of every variable free in the loop body or loop form.
    usedInBody =
      mconcat $ map (`lookupAliases` scope) $ namesToList $ freeIn body <> freeIn form

    -- Each body result paired with the pattern element it is bound to.
    resmap = zip (bodyResult body) $ patIdents pat

    -- Build the new merge parameters.  Unchanged ones come first and
    -- lowered ones last, together with the pre- and post-loop
    -- statements the lowered ones need.
    mkMerges ::
      (MonadFreshNames m, Buildable rep) =>
      [LoopResultSummary (als, Type)] ->
      m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
    mkMerges summaries = do
      ((origmerge, extramerge), (prestms, poststms)) <-
        runWriterT $ partitionEithers <$> mapM mkMerge summaries
      pure (origmerge ++ extramerge, prestms, poststms)

    -- For a summary with a related update, emit:
    --
    -- * pre-loop: write the merge initial value into the update's
    --   source, binding it to a fresh name that initialises the new
    --   merge parameter;
    -- * post-loop: read the element back from the update's result
    --   name and copy it to the update's value name.
    --
    -- Summaries without an update keep their original merge parameter.
    mkMerge summary
      | Just (update, mergename, mergedec) <- relatedUpdate summary = do
          source <- newVName "modified_source"
          precopy <- newVName $ baseString (updateValue update) <> "_precopy"
          let source_t = snd $ updateType update
              elm_t = source_t `setArrayDims` sliceDims (updateIndices update)
              full_is = fullSlice source_t $ unSlice $ updateIndices update
          tell
            ( [ mkLet [Ident source source_t] . BasicOp $
                  Update Unsafe (updateSource update) full_is $
                    snd $ mergeParam summary
              ],
              [ mkLet [Ident precopy elm_t] . BasicOp $
                  Index (updateName update) full_is,
                mkLet [Ident (updateValue update) elm_t] $ BasicOp $ Copy precopy
              ]
            )
          pure $
            Right
              ( Param mempty mergename (toDecl (typeOf mergedec) Unique),
                Var source
              )
      | otherwise = pure $ Left $ mergeParam summary

    -- The new loop pattern.  Unchanged pattern elements come first and
    -- the results of lowered updates last, in the same order as mkMerges.
    mkResAndPat summaries =
      let (origpat, extrapat) = partitionEithers $ map mkResAndPat' summaries
       in origpat ++ extrapat

    mkResAndPat' summary
      | Just (update, _, _) <- relatedUpdate summary =
          Right (Ident (updateName update) (snd $ updateType update))
      | otherwise =
          Left (inPatAs summary)
-- | Pair each loop result with its merge parameter and summarise it.
--
-- A result is accepted only when all of the following hold:
--
-- * it is the value of some desired update;
-- * @usedInBody@ (as computed by the caller) does not intersect the
--   aliases of that update's source;
-- * the merge parameter's shape does not mention any merge parameter.
--
-- If any result fails these conditions, the whole summary is
-- 'Nothing'.  This includes results that have no desired update, so
-- the behaviour is deliberately conservative.
summariseLoop ::
  ( Aliased rep,
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (als, Type)] ->
  Names ->
  [(SubExpRes, Ident)] ->
  [(Param DeclType, SubExp)] ->
  Maybe (m [LoopResultSummary (als, Type)])
summariseLoop scope updates usedInBody resmap merge =
  sequence <$> zipWithM summariseLoopResult resmap merge
  where
    summariseLoopResult (se, v) (fparam, mergeinit)
      | Just update <- find (updateHasValue $ identName v) updates,
        not $ usedInBody `namesIntersect` lookupAliases (updateSource update) scope,
        hasLoopInvariantShape fparam =
          Just $ do
            lowered_array <- newVName "lowered_array"
            pure
              LoopResultSummary
                { resultSubExp = se,
                  inPatAs = v,
                  mergeParam = (fparam, mergeinit),
                  relatedUpdate = Just (update, lowered_array, updateType update)
                }
    summariseLoopResult _ _ =
      Nothing

    -- A shape is loop-invariant if no dimension is a merge parameter.
    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names = map (paramName . fst) merge

    loopInvariant (Var v) = v `notElem` merge_param_names
    loopInvariant Constant {} = True
-- | What is known about one loop result that is being considered for
-- in-place lowering.
data LoopResultSummary dec = LoopResultSummary
  { -- | The loop body result.
    resultSubExp :: SubExpRes,
    -- | The pattern element the result is bound to.
    inPatAs :: Ident,
    -- | The corresponding merge parameter and its initial value.
    mergeParam :: (Param DeclType, SubExp),
    -- | The desired update for this result, if any.  It is paired with
    -- the name to use for the new merge parameter and the decoration
    -- that parameter's type is derived from.
    relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
  }
  deriving (Show)
-- | For every summary with a related update, map the original merge
-- parameter's name to a tuple built from that update.  The tuple holds
-- the update's certificates, the new merge-parameter name, its type,
-- and the update's indices.  Summaries without an update are skipped.
indexSubstitutions :: Typed dec => [LoopResultSummary dec] -> IndexSubstitutions
indexSubstitutions = mapMaybe getSubstitution
  where
    getSubstitution res = do
      (DesiredUpdate _ _ cs _ is _, nm, dec) <- relatedUpdate res
      let name = paramName $ fst $ mergeParam res
      pure (name, (cs, nm, typeOf dec, is))
-- | Rewrite the loop results so that results tied to a desired update are
-- written in place.
--
-- Results whose summary has no 'relatedUpdate' are returned untouched,
-- and they come first. The remaining results follow, each paired
-- positionally with an entry of the substitution list (via 'zipWithM',
-- so the shorter list decides the length).
--
-- For each pair:
--
-- * If the result is exactly the substituted variable, it becomes a
--   reference to the substitution's target name, keeping the result's
--   certificates.
-- * Otherwise an 'Update' statement is emitted. It writes the result
--   into the target at the substitution's (full) slice, and is certified
--   by both the result's and the substitution's certificates. The result
--   becomes the fresh @_updated@ name.
--
-- Returns the new results and the statements that must precede them.
--
-- NOTE(review): this reorders results (non-updated first); presumably the
-- caller orders the merge parameters the same way — confirm.
manipulateResult ::
  (Buildable rep, MonadFreshNames m) =>
  [LoopResultSummary (LetDec rep)] ->
  IndexSubstitutions ->
  m (Result, Stms rep)
manipulateResult summaries substs = do
  let (untouched, touched) = partitionEithers $ map classify summaries
  (rewritten, emitted) <- runWriterT $ zipWithM rewrite touched substs
  pure (untouched ++ rewritten, stmsFromList emitted)
  where
    -- Left: keep as-is; Right: needs rewriting against a substitution.
    classify summary
      | isNothing (relatedUpdate summary) = Left (resultSubExp summary)
      | otherwise = Right (resultSubExp summary)

    rewrite (SubExpRes res_cs res_se) (subst_v, (cs, nm, dec, is))
      | Var res_v <- res_se,
        res_v == subst_v =
          pure $ SubExpRes res_cs $ Var nm
      | otherwise = do
          fresh <- newIdent' (++ "_updated") $ Ident nm $ typeOf dec
          let update = Update Unsafe nm (fullSlice (typeOf dec) is) res_se
              stm = mkLet [fresh] $ BasicOp update
          tell [certify (res_cs <> cs) stm]
          pure $ varRes $ identName fresh