{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
( lowerUpdateGPU,
lowerUpdate,
LowerUpdate,
DesiredUpdate (..),
)
where
import Control.Monad
import Control.Monad.Writer
import Data.Either
import Data.List (find, unzip4)
import Data.Maybe (isNothing, mapMaybe)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.Optimise.InPlaceLowering.SubstituteIndices
-- | A request to perform an in-place update instead of computing a value
-- separately.  Conceptually it describes
-- @updateName = updateSource with [updateIndices] = updateValue@.
data DesiredUpdate dec = DesiredUpdate
  { -- | Name bound to the array produced by the update.
    updateName :: VName,
    -- | Annotation of 'updateName'; its type is obtained via 'typeOf'.
    updateType :: dec,
    -- | Certificates attached to the statement performing the update.
    updateCerts :: Certs,
    -- | The array that is written into.
    updateSource :: VName,
    -- | Where in 'updateSource' the value is written.
    updateIndices :: Slice SubExp,
    -- | The variable whose value is written into 'updateSource'.
    updateValue :: VName
  }
  deriving (Show)
-- | Maps over the 'updateType' annotation only; every other field is
-- carried over unchanged.
instance Functor DesiredUpdate where
  fmap f u = u {updateType = f (updateType u)}
-- | Is the given name the value written by this update?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name u = name == updateValue u
-- | A strategy for lowering desired in-place updates into a statement.
-- Given the current scope, a statement, and the updates wanted for its
-- results, it yields 'Nothing' when the updates cannot be lowered into
-- that statement, or an action producing the statements to use instead.
type LowerUpdate rep m =
  Scope (Aliases rep)
  -> Stm (Aliases rep)
  -> [DesiredUpdate (LetDec (Aliases rep))]
  -> Maybe (m [Stm (Aliases rep)])
-- | Try to lower a list of desired in-place updates into a statement.
--
-- Two shapes of statement are handled:
--
-- * A @DoLoop@: the updates are pushed into the loop by
--   'lowerUpdateIntoLoop'; the rebuilt loop is surrounded by the
--   pre- and post-statements that function produces.
--
-- * A copy @src = v@ with exactly one desired update targeting @src@: the
--   copy is replaced by an @Update@ of @v@ bound to the update's name.
--
-- Any other statement yields 'Nothing'.
lowerUpdate ::
  ( MonadFreshNames m,
    Buildable rep,
    LetDec rep ~ Type,
    AliasableRep rep
  ) =>
  LowerUpdate rep m
lowerUpdate scope (Let pat aux (DoLoop merge form body)) updates = do
  canDo <- lowerUpdateIntoLoop scope updates pat merge form body
  Just $ do
    (prestms, poststms, pat', merge', body') <- canDo
    -- The original statement's certificates are re-attached to the
    -- rebuilt loop.
    let loop = certify (stmAuxCerts aux) $ mkLet pat' $ DoLoop merge' form body'
    pure $ prestms ++ [loop] ++ poststms
lowerUpdate
  _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src (Slice is) val]
    | patNames pat == [src] =
        let is' = fullSlice (typeOf bindee_dec) is
            upd = BasicOp $ Update Unsafe v is' $ Var val
         in Just . pure $
              [ certify (stmAuxCerts aux <> cs) $
                  mkLet [Ident bindee_nm $ typeOf bindee_dec] upd
              ]
lowerUpdate _ _ _ = Nothing
-- | Like 'lowerUpdate', but additionally able to lower updates into the
-- results of a @SegMap@.  That is attempted only when every update's value
-- is bound by the @SegMap@'s pattern and no update source (or anything it
-- aliases) may alias a name free in the kernel body.  All other statements
-- are handed to 'lowerUpdate'.
lowerUpdateGPU :: MonadFreshNames m => LowerUpdate GPU m
lowerUpdateGPU
  scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody))))
  updates
    | all ((`elem` patNames pat) . updateValue) updates,
      not source_used_in_kbody = do
        mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
        Just $ do
          (pat', kbody', poststms) <- mk
          let cs = stmAuxCerts aux <> foldMap updateCerts updates
              segmap = Op $ SegOp $ SegMap lvl space ts kbody'
          pure $ certify cs (Let pat' aux segmap) : stmsToList poststms
    where
      -- Union of the aliases (as recorded in the scope) of the given names.
      aliasesOf :: [VName] -> Names
      aliasesOf = foldMap (`lookupAliases` scope)

      -- True if something the kernel body refers to may alias the source
      -- of some update.  NOTE(review): presumably rejected because an
      -- in-place write into that source would then be visible to the
      -- kernel -- confirm.
      source_used_in_kbody :: Bool
      source_used_in_kbody =
        aliasesOf (namesToList (freeIn kbody))
          `namesIntersect` aliasesOf (map updateSource updates)
lowerUpdateGPU scope stm updates = lowerUpdate scope stm updates
-- | Try to turn desired updates into 'WriteReturns' results of a @SegMap@
-- kernel body.
--
-- Each pattern element that is the value of some update is rebound to the
-- update's name, and its 'Returns' result becomes a 'WriteReturns' into the
-- update's source array.  Statements that recompute the original element
-- (by indexing the updated array) are returned as post-statements.
-- Pattern elements without a matching update are left unchanged.
--
-- The lowering is all-or-nothing: it yields 'Nothing' if any matched result
-- is not a plain 'Returns', or if a per-thread slice does not fully cover
-- the remaining dimensions.
lowerUpdatesIntoSegMap ::
  MonadFreshNames m =>
  Scope (Aliases GPU) ->
  Pat (LetDec (Aliases GPU)) ->
  [DesiredUpdate (LetDec (Aliases GPU))] ->
  SegSpace ->
  KernelBody (Aliases GPU) ->
  Maybe
    ( m
        ( Pat (LetDec (Aliases GPU)),
          KernelBody (Aliases GPU),
          Stms (Aliases GPU)
        )
    )
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  -- A single failing result aborts the whole lowering.
  mk <- zipWithM onRet (patElems pat) (kernelBodyResult kbody)
  pure $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    pure
      ( Pat pes,
        kbody
          { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms,
            kernelBodyResult = krets
          },
        mconcat poststms
      )
  where
    -- Thread indices of the segment space.
    gtids = map fst $ unSegSpace kspace

    onRet (PatElem v v_dec) ret
      | Just (DesiredUpdate bindee_nm bindee_dec _cs src slice _val) <-
          find (updateHasValue v) updates = do
          Returns _ cs se <- Just ret

          -- Of the dimensions the update's slice does not fix, the first
          -- (length gtids) are covered by the thread indices; the slice must
          -- fully cover the rest.
          guard $
            let (dims', slice') =
                  unzip . drop (length gtids) . filter (isNothing . dimFix . snd) $
                    zip (arrayDims (typeOf bindee_dec)) (unSlice slice)
             in isFullSlice (Shape dims') (Slice slice')

          Just $ do
            -- Compute, inside the kernel body, the indices this thread
            -- writes to.
            (slice', bodystms) <-
              flip runBuilderT scope $
                traverse (toSubExp "index") $
                  fixSlice (fmap pe64 slice) $
                    map (pe64 . Var) gtids

            let res_dims = take (length slice') $ arrayDims $ snd bindee_dec
                ret' = WriteReturns cs (Shape res_dims) src [(Slice $ map DimFix slice', se)]

            v_aliased <- newName v

            pure
              ( PatElem bindee_nm bindee_dec,
                bodystms,
                ret',
                stmsFromList
                  [ mkLet [Ident v_aliased $ typeOf v_dec] $ BasicOp $ Index bindee_nm slice,
                    mkLet [Ident v $ typeOf v_dec] $ BasicOp $ Replicate mempty $ Var v_aliased
                  ]
              )
    onRet pe ret =
      Just $ pure (pe, mempty, ret, mempty)
lowerUpdateIntoLoop ::
( Buildable rep,
BuilderOps rep,
Aliased rep,
LetDec rep ~ (als, Type),
MonadFreshNames m
) =>
Scope rep ->
[DesiredUpdate (LetDec rep)] ->
Pat (LetDec rep) ->
[(FParam rep, SubExp)] ->
LoopForm rep ->
Body rep ->
Maybe
( m
( [Stm rep],
[Stm rep],
[Ident],
[(FParam rep, SubExp)],
Body rep
)
)
lowerUpdateIntoLoop :: forall rep als (m :: * -> *).
(Buildable rep, BuilderOps rep, Aliased rep,
LetDec rep ~ (als, Type), MonadFreshNames m) =>
Scope rep
-> [DesiredUpdate (LetDec rep)]
-> Pat (LetDec rep)
-> [(FParam rep, SubExp)]
-> LoopForm rep
-> Body rep
-> Maybe
(m ([Stm rep], [Stm rep], [Ident], [(FParam rep, SubExp)],
Body rep))
lowerUpdateIntoLoop Scope rep
scope [DesiredUpdate (LetDec rep)]
updates Pat (LetDec rep)
pat [(FParam rep, SubExp)]
val LoopForm rep
form Body rep
body = do
[((Param DeclType, SubExp), Names)]
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Param DeclType, SubExp)]
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Param DeclType, SubExp)]
[(FParam rep, SubExp)]
val ([Names] -> [((Param DeclType, SubExp), Names)])
-> [Names] -> [((Param DeclType, SubExp), Names)]
forall a b. (a -> b) -> a -> b
$ Body rep -> [Names]
forall rep. Aliased rep => Body rep -> [Names]
bodyAliases Body rep
body) ((((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ())
-> (((Param DeclType, SubExp), Names) -> Maybe ()) -> Maybe ()
forall a b. (a -> b) -> a -> b
$ \((Param DeclType
p, SubExp
_), Names
als) ->
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p VName -> Names -> Bool
`notNameIn` Names
als
m [LoopResultSummary (als, Type)]
mk_in_place_map <- Scope rep
-> [DesiredUpdate (als, Type)]
-> Names
-> [(SubExpRes, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
forall rep (m :: * -> *) als.
(Aliased rep, MonadFreshNames m) =>
Scope rep
-> [DesiredUpdate (als, Type)]
-> Names
-> [(SubExpRes, Ident)]
-> [(Param DeclType, SubExp)]
-> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop Scope rep
scope [DesiredUpdate (als, Type)]
[DesiredUpdate (LetDec rep)]
updates Names
usedInBody [(SubExpRes, Ident)]
resmap [(Param DeclType, SubExp)]
[(FParam rep, SubExp)]
val
m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep)
-> Maybe
(m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep))
forall a. a -> Maybe a
Just (m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep)
-> Maybe
(m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep)))
-> m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep)
-> Maybe
(m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep))
forall a b. (a -> b) -> a -> b
$ do
[LoopResultSummary (als, Type)]
in_place_map <- m [LoopResultSummary (als, Type)]
mk_in_place_map
([(Param DeclType, SubExp)]
val', [Stm rep]
prestms, [Stm rep]
poststms) <- [LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
forall (m :: * -> *) rep als.
(MonadFreshNames m, Buildable rep) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
mkMerges [LoopResultSummary (als, Type)]
in_place_map
let valpat :: [Ident]
valpat = [LoopResultSummary (als, Type)] -> [Ident]
forall {a}. [LoopResultSummary (a, Type)] -> [Ident]
mkResAndPat [LoopResultSummary (als, Type)]
in_place_map
idxsubsts :: IndexSubstitutions
idxsubsts = [LoopResultSummary (als, Type)] -> IndexSubstitutions
forall dec.
Typed dec =>
[LoopResultSummary dec] -> IndexSubstitutions
indexSubstitutions [LoopResultSummary (als, Type)]
in_place_map
(IndexSubstitutions
idxsubsts', Stms rep
newstms) <- IndexSubstitutions -> Stms rep -> m (IndexSubstitutions, Stms rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, BuilderOps rep, Buildable rep, Aliased rep) =>
IndexSubstitutions -> Stms rep -> m (IndexSubstitutions, Stms rep)
substituteIndices IndexSubstitutions
idxsubsts (Stms rep -> m (IndexSubstitutions, Stms rep))
-> Stms rep -> m (IndexSubstitutions, Stms rep)
forall a b. (a -> b) -> a -> b
$ Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms Body rep
body
(Result
body_res, Stms rep
res_stms) <- [LoopResultSummary (LetDec rep)]
-> IndexSubstitutions -> m (Result, Stms rep)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[LoopResultSummary (LetDec rep)]
-> IndexSubstitutions -> m (Result, Stms rep)
manipulateResult [LoopResultSummary (als, Type)]
[LoopResultSummary (LetDec rep)]
in_place_map IndexSubstitutions
idxsubsts'
let body' :: Body rep
body' = Stms rep -> Result -> Body rep
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stms rep
newstms Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Stms rep
res_stms) Result
body_res
([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep)
-> m ([Stm rep], [Stm rep], [Ident], [(Param DeclType, SubExp)],
Body rep)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([Stm rep]
prestms, [Stm rep]
poststms, [Ident]
valpat, [(Param DeclType, SubExp)]
val', Body rep
body')
where
usedInBody :: Names
usedInBody =
[Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (VName -> Names) -> [VName] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> Scope rep -> Names
forall rep. AliasesOf (LetDec rep) => VName -> Scope rep -> Names
`lookupAliases` Scope rep
scope) ([VName] -> [Names]) -> [VName] -> [Names]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Body rep -> Names
forall a. FreeIn a => a -> Names
freeIn Body rep
body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> LoopForm rep -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm rep
form
resmap :: [(SubExpRes, Ident)]
resmap = Result -> [Ident] -> [(SubExpRes, Ident)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Body rep -> Result
forall rep. Body rep -> Result
bodyResult Body rep
body) ([Ident] -> [(SubExpRes, Ident)])
-> [Ident] -> [(SubExpRes, Ident)]
forall a b. (a -> b) -> a -> b
$ Pat (als, Type) -> [Ident]
forall dec. Typed dec => Pat dec -> [Ident]
patIdents Pat (als, Type)
Pat (LetDec rep)
pat
mkMerges ::
(MonadFreshNames m, Buildable rep) =>
[LoopResultSummary (als, Type)] ->
m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
mkMerges :: forall (m :: * -> *) rep als.
(MonadFreshNames m, Buildable rep) =>
[LoopResultSummary (als, Type)]
-> m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
mkMerges [LoopResultSummary (als, Type)]
summaries = do
(([(Param DeclType, SubExp)]
origmerge, [(Param DeclType, SubExp)]
extramerge), ([Stm rep]
prestms, [Stm rep]
poststms)) <-
WriterT
([Stm rep], [Stm rep])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm rep], [Stm rep]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
([Stm rep], [Stm rep])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm rep], [Stm rep])))
-> WriterT
([Stm rep], [Stm rep])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
-> m (([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]),
([Stm rep], [Stm rep]))
forall a b. (a -> b) -> a -> b
$ [Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> ([(Param DeclType, SubExp)], [(Param DeclType, SubExp)]))
-> WriterT
([Stm rep], [Stm rep])
m
[Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
-> WriterT
([Stm rep], [Stm rep])
m
([(Param DeclType, SubExp)], [(Param DeclType, SubExp)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (LoopResultSummary (als, Type)
-> WriterT
([Stm rep], [Stm rep])
m
(Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> [LoopResultSummary (als, Type)]
-> WriterT
([Stm rep], [Stm rep])
m
[Either (Param DeclType, SubExp) (Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM LoopResultSummary (als, Type)
-> WriterT
([Stm rep], [Stm rep])
m
(Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall {rep} {rep} {m :: * -> *} {a}.
(MonadFreshNames m, MonadWriter ([Stm rep], [Stm rep]) m,
Buildable rep, Buildable rep) =>
LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge [LoopResultSummary (als, Type)]
summaries
([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
-> m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([(Param DeclType, SubExp)]
origmerge [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
extramerge, [Stm rep]
prestms, [Stm rep]
poststms)
mkMerge :: LoopResultSummary (a, Type)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
mkMerge LoopResultSummary (a, Type)
summary
| Just (DesiredUpdate (a, Type)
update, VName
mergename, (a, Type)
mergedec) <- LoopResultSummary (a, Type)
-> Maybe (DesiredUpdate (a, Type), VName, (a, Type))
forall dec.
LoopResultSummary dec -> Maybe (DesiredUpdate dec, VName, dec)
relatedUpdate LoopResultSummary (a, Type)
summary = do
VName
source <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"modified_source"
VName
precopy <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> m VName) -> String -> m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString (DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateValue DesiredUpdate (a, Type)
update) String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_precopy"
let source_t :: Type
source_t = (a, Type) -> Type
forall a b. (a, b) -> b
snd ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall dec. DesiredUpdate dec -> dec
updateType DesiredUpdate (a, Type)
update
elm_t :: Type
elm_t = Type
source_t Type -> [SubExp] -> Type
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims (DesiredUpdate (a, Type) -> Slice SubExp
forall dec. DesiredUpdate dec -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update)
([Stm rep], [Stm rep]) -> m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
( [ [Ident] -> Exp rep -> Stm rep
forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet [VName -> Type -> Ident
Ident VName
source Type
source_t] (Exp rep -> Stm rep) -> (BasicOp -> Exp rep) -> BasicOp -> Stm rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp
(BasicOp -> Stm rep) -> BasicOp -> Stm rep
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update
Safety
Unsafe
(DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateSource DesiredUpdate (a, Type)
update)
(Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
source_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice (Slice SubExp -> [DimIndex SubExp])
-> Slice SubExp -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall dec. DesiredUpdate dec -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update)
(SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp) -> SubExp
forall a b. (a, b) -> b
snd
((Param DeclType, SubExp) -> SubExp)
-> (Param DeclType, SubExp) -> SubExp
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall dec. LoopResultSummary dec -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary
],
[ [Ident] -> Exp rep -> Stm rep
forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet [VName -> Type -> Ident
Ident VName
precopy Type
elm_t] (Exp rep -> Stm rep) -> (BasicOp -> Exp rep) -> BasicOp -> Stm rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Stm rep) -> BasicOp -> Stm rep
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index
(DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateName DesiredUpdate (a, Type)
update)
(Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
source_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice (Slice SubExp -> [DimIndex SubExp])
-> Slice SubExp -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> Slice SubExp
forall dec. DesiredUpdate dec -> Slice SubExp
updateIndices DesiredUpdate (a, Type)
update),
[Ident] -> Exp rep -> Stm rep
forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet [VName -> Type -> Ident
Ident (DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateValue DesiredUpdate (a, Type)
update) Type
elm_t] (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
precopy
]
)
Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$
(Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. b -> Either a b
Right
( Attrs -> VName -> DeclType -> Param DeclType
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
mergename (Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl ((a, Type) -> Type
forall t. Typed t => t -> Type
typeOf (a, Type)
mergedec) Uniqueness
Unique),
VName -> SubExp
Var VName
source
)
| Bool
otherwise = Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp)))
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
-> m (Either (Param DeclType, SubExp) (Param DeclType, SubExp))
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. a -> Either a b
Left ((Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp))
-> (Param DeclType, SubExp)
-> Either (Param DeclType, SubExp) (Param DeclType, SubExp)
forall a b. (a -> b) -> a -> b
$ LoopResultSummary (a, Type) -> (Param DeclType, SubExp)
forall dec. LoopResultSummary dec -> (Param DeclType, SubExp)
mergeParam LoopResultSummary (a, Type)
summary
mkResAndPat :: [LoopResultSummary (a, Type)] -> [Ident]
mkResAndPat [LoopResultSummary (a, Type)]
summaries =
let ([Ident]
origpat, [Ident]
extrapat) = [Either Ident Ident] -> ([Ident], [Ident])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either Ident Ident] -> ([Ident], [Ident]))
-> [Either Ident Ident] -> ([Ident], [Ident])
forall a b. (a -> b) -> a -> b
$ (LoopResultSummary (a, Type) -> Either Ident Ident)
-> [LoopResultSummary (a, Type)] -> [Either Ident Ident]
forall a b. (a -> b) -> [a] -> [b]
map LoopResultSummary (a, Type) -> Either Ident Ident
forall {a}. LoopResultSummary (a, Type) -> Either Ident Ident
mkResAndPat' [LoopResultSummary (a, Type)]
summaries
in [Ident]
origpat [Ident] -> [Ident] -> [Ident]
forall a. [a] -> [a] -> [a]
++ [Ident]
extrapat
mkResAndPat' :: LoopResultSummary (a, Type) -> Either Ident Ident
mkResAndPat' LoopResultSummary (a, Type)
summary
| Just (DesiredUpdate (a, Type)
update, VName
_, (a, Type)
_) <- LoopResultSummary (a, Type)
-> Maybe (DesiredUpdate (a, Type), VName, (a, Type))
forall dec.
LoopResultSummary dec -> Maybe (DesiredUpdate dec, VName, dec)
relatedUpdate LoopResultSummary (a, Type)
summary =
Ident -> Either Ident Ident
forall a b. b -> Either a b
Right (VName -> Type -> Ident
Ident (DesiredUpdate (a, Type) -> VName
forall dec. DesiredUpdate dec -> VName
updateName DesiredUpdate (a, Type)
update) ((a, Type) -> Type
forall a b. (a, b) -> b
snd ((a, Type) -> Type) -> (a, Type) -> Type
forall a b. (a -> b) -> a -> b
$ DesiredUpdate (a, Type) -> (a, Type)
forall dec. DesiredUpdate dec -> dec
updateType DesiredUpdate (a, Type)
update))
| Bool
otherwise =
Ident -> Either Ident Ident
forall a b. a -> Either a b
Left (LoopResultSummary (a, Type) -> Ident
forall dec. LoopResultSummary dec -> Ident
inPatAs LoopResultSummary (a, Type)
summary)
-- | Summarise every result of a loop with respect to the in-place
-- updates we would like to lower into it.
--
-- Each body result (paired with the pattern identifier it is bound to)
-- is matched positionally with the corresponding merge parameter and its
-- initial value.
--
-- * If 'updateHasValue' finds a desired update for the result's pattern
--   name, the summary records that update, a fresh @lowered_array@ name,
--   and the update's decoration.
-- * Otherwise the result gets a summary with 'relatedUpdate' set to
--   'Nothing'. Callers treat such results as pass-through (they end up in
--   the @Left@ halves of their partitions). Previously such results
--   aborted the whole lowering.
--
-- The outer 'Nothing' (lowering abandoned) is returned when, for some
-- updated result:
--
-- * the names in @usedInBody@ intersect the aliases of the update's
--   source array, or
-- * the merge parameter's array dimensions refer to merge parameters,
--   i.e. its shape is not loop-invariant.
--
-- Fresh names are only drawn when the inner monadic action is run.
summariseLoop ::
  ( Aliased rep,
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (als, Type)] ->
  Names ->
  [(SubExpRes, Ident)] ->
  [(Param DeclType, SubExp)] ->
  Maybe (m [LoopResultSummary (als, Type)])
summariseLoop scope updates usedInBody resmap merge =
  sequence <$> zipWithM summariseLoopResult resmap merge
  where
    summariseLoopResult (se, v) (fparam, mergeinit) =
      case find (updateHasValue $ identName v) updates of
        -- No update targets this result: keep it unchanged rather than
        -- rejecting the entire loop.
        Nothing ->
          Just $ pure $ summary Nothing
        Just update
          -- The update's source must not be observable inside the loop.
          | usedInBody `namesIntersect` lookupAliases (updateSource update) scope ->
              Nothing
          -- The lowered array is threaded through the loop, so its shape
          -- may not vary between iterations.
          | not $ hasLoopInvariantShape fparam ->
              Nothing
          | otherwise -> Just $ do
              lowered_array <- newVName "lowered_array"
              pure $ summary $ Just (update, lowered_array, updateType update)
      where
        summary rel =
          LoopResultSummary
            { resultSubExp = se,
              inPatAs = v,
              mergeParam = (fparam, mergeinit),
              relatedUpdate = rel
            }

    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names = map (paramName . fst) merge

    -- A dimension is loop-invariant when it is a constant or a variable
    -- that is not itself one of the loop's merge parameters.
    loopInvariant (Var v) = v `notElem` merge_param_names
    loopInvariant Constant {} = True
-- | Everything the lowering needs to know about one result of a loop.
--
-- The @dec@ parameter is the decoration type carried by the desired
-- update and by the lowered array.
data LoopResultSummary dec = LoopResultSummary
  { -- | The value the loop body returns for this result.
    resultSubExp :: SubExpRes,
    -- | The identifier this result is bound to in the loop's pattern.
    inPatAs :: Ident,
    -- | The merge parameter for this result, paired with its initial value.
    mergeParam :: (Param DeclType, SubExp),
    -- | When an in-place update is being lowered into this result: the
    -- update itself, the name chosen for the lowered array, and that
    -- array's decoration. 'Nothing' means the result passes through
    -- unchanged.
    relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
  }
  deriving Show
-- | Build the index substitutions implied by a list of loop-result
-- summaries.
--
-- Only summaries with a related update contribute, in their original
-- order. Each one maps the name of its merge parameter to a 4-tuple:
--
-- * the update's certificates,
-- * the lowered array name,
-- * the lowered array's type,
-- * the update's index slice.
indexSubstitutions :: Typed dec => [LoopResultSummary dec] -> IndexSubstitutions
indexSubstitutions summaries =
  [ (paramName $ fst $ mergeParam summary, (certs, lowered, typeOf dec, slice))
    | summary <- summaries,
      Just (DesiredUpdate {updateCerts = certs, updateIndices = slice}, lowered, dec) <-
        [relatedUpdate summary]
  ]
-- | Rewrite the result list of a loop body after lowering.
--
-- The output lists, in order:
--
-- * the results whose summary has no related update, unchanged;
-- * the results that do have one, each rewritten against the index
--   substitution at the same position in @substs@.
--
-- NOTE(review): the pairing is positional via 'zipWithM'. This assumes
-- @substs@ lists the updated results in summary order; confirm with the
-- caller.
--
-- Rewriting an updated result works as follows:
--
-- * If the result is exactly the substituted variable, it is replaced
--   by the lowered array name.
-- * Otherwise the result is written into the lowered array at the
--   substitution's slice by a fresh @*_updated@ binding, certified by
--   both the result's and the substitution's certificates. That binding
--   becomes the new result.
--
-- Any statements generated this way are returned alongside the result.
manipulateResult ::
  (Buildable rep, MonadFreshNames m) =>
  [LoopResultSummary (LetDec rep)] ->
  IndexSubstitutions ->
  m (Result, Stms rep)
manipulateResult summaries substs = do
  (rewritten, extra_stms) <- runWriterT $ zipWithM rewriteRes updated substs
  pure (untouched ++ rewritten, stmsFromList extra_stms)
  where
    untouched =
      [resultSubExp s | s <- summaries, isNothing (relatedUpdate s)]
    updated =
      [resultSubExp s | s <- summaries, not (isNothing (relatedUpdate s))]

    rewriteRes (SubExpRes res_cs res_se) (subst_v, (cs, nm, dec, Slice is))
      | Var res_v <- res_se,
        res_v == subst_v =
          pure $ SubExpRes res_cs $ Var nm
      | otherwise = do
          let arr_t = typeOf dec
          v' <- newIdent' (<> "_updated") $ Ident nm arr_t
          tell
            [ certify (res_cs <> cs) $
                mkLet [v'] $
                  BasicOp $
                    Update Unsafe nm (fullSlice arr_t is) res_se
            ]
          pure $ varRes $ identName v'