{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module implements an optimisation that moves in-place
-- updates into/before loops where possible, with the end goal of
-- minimising memory copies.  As an example, consider this program:
--
-- @
--   let r =
--     loop (r1 = r0) = for i < n do
--       let a = r1[i]
--       let r1[i] = a * i
--       in r1
--   ...
--   let x = y with [k] <- r in
--   ...
-- @
--
-- We want to turn this into the following:
--
-- @
--   let x0 = y with [k] <- r0
--   loop (x = x0) = for i < n do
--     let a = x[k,i]
--     let x[k,i] = a * i
--     in x
--   let r = x[k] in
--   ...
-- @
--
-- The intent is that we are also going to optimise the new data
-- movement (in the @x0@-binding), possibly by changing how @r0@ is
-- defined.  For the above transformation to be valid, a number of
-- conditions must be fulfilled:
--
--    (1) @r@ must not be consumed after the original in-place update.
--
--    (2) @k@ and @y@ must be available at the beginning of the loop.
--
--    (3) @x@ must be visible whenever @r@ is visible.  (This means
--    that both @x@ and @r@ must be bound in the same t'Body'.)
--
--    (4) If @x@ is consumed at a point after the loop, @r@ must not
--    be used after that point.
--
--    (5) The size of @r1@ is invariant inside the loop.
--
--    (6) The value @r@ must come from something that we can actually
--    optimise (e.g. not a function parameter).
--
--    (7) @y@ (or its aliases) may not be used inside the body of the
--    loop.
--
--    (8) The result of the loop may not alias the merge parameter
--    @r1@.
--
--    (9) @y@ or its aliases may not be used after the loop.
--
-- FIXME: the implementation is not finished yet.  Specifically, not
-- all of the above conditions are checked.
module Futhark.Optimise.InPlaceLowering
  ( inPlaceLoweringGPU,
    inPlaceLoweringSeq,
    inPlaceLoweringMC,
  )
where

import Control.Monad.RWS
import Data.Map.Strict qualified as M
import Data.Ord (comparing)
import Futhark.Analysis.Alias
import Futhark.Builder
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.IR.Seq (Seq)
import Futhark.Optimise.InPlaceLowering.LowerIntoStm
import Futhark.Pass
import Futhark.Util (nubByOrd)

-- | Apply the in-place lowering optimisation to the given GPU program.
-- Kernel operations are descended into with 'onKernelOp', and desired
-- updates are lowered with 'lowerUpdateGPU'.
inPlaceLoweringGPU :: Pass GPU GPU
inPlaceLoweringGPU = inPlaceLowering onKernelOp lowerUpdateGPU

-- | Apply the in-place lowering optimisation to the given sequential
-- program.  Operations are left untouched ('pure' is used as the
-- operation handler), and desired updates are lowered with
-- 'lowerUpdate'.
inPlaceLoweringSeq :: Pass Seq Seq
inPlaceLoweringSeq = inPlaceLowering pure lowerUpdate

-- | Apply the in-place lowering optimisation to the given multicore
-- program.  Multicore operations are descended into with 'onMCOp', and
-- desired updates are lowered with 'lowerUpdate'.
inPlaceLoweringMC :: Pass MC MC
inPlaceLoweringMC = inPlaceLowering onMCOp lowerUpdate

-- | Apply the in-place lowering optimisation to the given program.
--
-- The program is first annotated with alias information.  The
-- top-level constants and every function body are then optimised, and
-- finally the alias annotations are removed again.
inPlaceLowering ::
  Constraints rep =>
  -- | How to descend into representation-specific operations.
  OnOp rep ->
  -- | How to lower desired updates into a statement.
  LowerUpdate rep (ForwardingM rep) ->
  Pass rep rep
inPlaceLowering onOp lower =
  Pass "In-place lowering" "Lower in-place updates into loops" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts optimiseConsts optimiseFunDef
      . aliasAnalysis
  where
    optimiseConsts stms =
      modifyNameSource $
        runForwardingM lower onOp $
          stmsFromList <$> optimiseStms (stmsToList stms) (pure ())

    -- The top-level constants are brought into scope before the
    -- function parameters and body.
    optimiseFunDef consts fundec =
      modifyNameSource $
        runForwardingM lower onOp $
          descend (stmsToList consts) $
            bindingFParams (funDefParams fundec) $ do
              body <- optimiseBody $ funDefBody fundec
              pure $ fundec {funDefBody = body}

    -- Bind the statements in order (the first one outermost) around
    -- the given action.
    descend stms m = foldr bindingStm m stms

-- | The constraints a representation must satisfy for
-- 'inPlaceLowering' to be applicable to it.
type Constraints rep = (Buildable rep, CanBeAliased (Op rep))

-- | Optimise the statements of a body under 'deepen'.  The variables
-- appearing in the body result are marked as seen, in the scope of all
-- of the body's statements.
optimiseBody ::
  Constraints rep =>
  Body (Aliases rep) ->
  ForwardingM rep (Body (Aliases rep))
optimiseBody (Body als stms res) = do
  stms' <-
    deepen . optimiseStms (stmsToList stms) $
      mapM_ (seen . resSubExp) res
  pure $ Body als (stmsFromList stms') res
  where
    -- Constants mention no variables.
    seen Constant {} = pure ()
    seen (Var v) = seenVar v

-- | Optimise a list of statements, running the given action in the
-- scope of all of them.
--
-- The statements are handled back to front.  The tail is optimised
-- first, and the desired updates it reports ('forwardThese') are then
-- offered to the head statement.  An update is eligible when two
-- conditions hold:
--
--   * its value is bound by the head statement;
--   * its source is not seen afterwards (condition (9)).
--
-- Eligible updates are passed to 'topLowerUpdate'.  If lowering
-- succeeds, the lowered statements (optimised once more) replace the
-- head statement.  The statements binding the forwarded updates are
-- then dropped from the tail.
optimiseStms ::
  Constraints rep =>
  [Stm (Aliases rep)] ->
  ForwardingM rep () ->
  ForwardingM rep [Stm (Aliases rep)]
optimiseStms [] m = m >> pure []
optimiseStms (stm : stms) m = do
  (stms', bup) <- tapBottomUp $ bindingStm stm $ optimiseStms stms m
  stm' <- optimiseInStm stm
  let -- Keep the statement unchanged, but check whether it is itself an
      -- in-place update that could be forwarded further up.
      keepStm = do
        checkIfForwardableUpdate stm'
        pure $ stm' : stms'
  -- XXX: unfortunate that we cannot handle duplicate update values.
  -- Would be good to improve this.  See inplacelowering6.fut.
  case nubByOrd (comparing updateValue)
    . filter ((`notNameIn` bottomUpSeen bup) . updateSource) -- (9)
    . filter ((`elem` boundHere) . updateValue)
    $ forwardThese bup of
    [] -> keepStm
    updates -> do
      lower <- asks topLowerUpdate
      scope <- askScope

      -- If we forward any updates, we need to remove them from stms'.
      let updated_names = map updateName updates
          notUpdated = not . any (`elem` updated_names) . patNames . stmPat

      -- Condition (5) and (7) are assumed to be checked by
      -- lowerUpdate.
      case lower scope stm' updates of
        Just lowering -> do
          new_stms <- lowering
          new_stms' <- optimiseStms new_stms $ tell bup {forwardThese = []}
          pure $ new_stms' ++ filter notUpdated stms'
        Nothing -> keepStm
  where
    -- Names bound by the statement under consideration.
    boundHere = patNames $ stmPat stm

    -- An unsafe in-place update of a single variable with a variable
    -- value is handed to 'maybeForward'.  For any other statement,
    -- every free variable of its expression is marked as seen.
    checkIfForwardableUpdate (Let pat (StmAux cs _ _) e)
      | Pat [PatElem v dec] <- pat,
        BasicOp (Update Unsafe src slice (Var ve)) <- e =
          maybeForward ve v dec cs src slice
    checkIfForwardableUpdate other =
      mapM_ seenVar $ namesToList $ freeIn $ stmExp other

-- | Optimise the expression of a statement.  The pattern and auxiliary
-- information are kept as they are.
optimiseInStm :: Constraints rep => Stm (Aliases rep) -> ForwardingM rep (Stm (Aliases rep))
optimiseInStm (Let pat dec e) = Let pat dec <$> optimiseExp e

-- | Optimise the bodies nested inside an expression.  The cases are:
--
--   * A loop has its form and merge parameters brought into scope
--     before its body is optimised.
--   * An operation is handed to the representation-specific 'topOnOp'.
--   * Any other expression has its nested bodies optimised through
--     'mapExpM'.
optimiseExp :: Constraints rep => Exp (Aliases rep) -> ForwardingM rep (Exp (Aliases rep))
optimiseExp (DoLoop merge form body) =
  bindingScope (scopeOf form) . bindingFParams (map fst merge) $
    DoLoop merge form <$> optimiseBody body
optimiseExp (Op op) = do
  onOp <- asks topOnOp
  Op <$> onOp op
optimiseExp e = mapExpM optimise e
  where
    optimise =
      identityMapper
        { mapOnBody = const optimiseBody
        }

-- | Optimise the kernel bodies of a 'SegOp', with the scope of its
-- segment space bound.  Inside each body, the variables used by the
-- kernel results are marked as seen, in the scope of the body's
-- statements.
onSegOp ::
  (Buildable rep, CanBeAliased (Op rep)) =>
  SegOp lvl (Aliases rep) ->
  ForwardingM rep (SegOp lvl (Aliases rep))
onSegOp op =
  bindingScope (scopeOfSegSpace (segSpace op)) $ do
    let mapper = identitySegOpMapper {mapOnSegOpBody = onKernelBody}
        onKernelBody kbody = do
          stms <-
            deepen . optimiseStms (stmsToList (kernelBodyStms kbody)) $
              mapM_ seenVar . namesToList . freeIn $
                kernelBodyResult kbody
          pure kbody {kernelBodyStms = stmsFromList stms}
    mapSegOpM mapper op

-- | Descend into both 'SegOp's of a multicore 'ParOp' (the optional
-- one included).  All other multicore operations are returned
-- unchanged.
onMCOp :: OnOp MC
onMCOp (ParOp par_op op) = ParOp <$> traverse onSegOp par_op <*> onSegOp op
onMCOp op = pure op

-- | Descend into the 'SegOp' of a GPU host operation.  All other host
-- operations are returned unchanged.
onKernelOp :: OnOp GPU
onKernelOp (SegOp op) = SegOp <$> onSegOp op
onKernelOp op = pure op

-- | What we know about a variable in scope.
data Entry rep = Entry
  { -- | Value of 'topDownCounter' when the variable was bound.
    entryNumber :: Int,
    -- | Aliases of the variable; empty for parameters.
    entryAliases :: Names,
    -- | Value of 'topDownDepth' when the variable was bound.
    entryDepth :: Int,
    -- | Whether the variable comes from something we can optimise.
    -- Always 'False' for parameters (cf. condition (6)).
    entryOptimisable :: Bool,
    -- | Type information, as exposed through 'askScope'.
    entryType :: NameInfo (Aliases rep)
  }

-- | Symbol table mapping each variable in scope to its 'Entry'.
type VTable rep = M.Map VName (Entry rep)

-- | A representation-specific traversal of an (alias-annotated)
-- operation, run inside 'ForwardingM'.
type OnOp rep = Op (Aliases rep) -> ForwardingM rep (Op (Aliases rep))

-- | The environment of 'ForwardingM', passed downwards through the
-- traversal.
data TopDown rep = TopDown
  { -- | Counter bumped when parameters are bound (see 'bindingParams');
    -- recorded in new entries as 'entryNumber'.
    topDownCounter :: Int,
    -- | Every variable in scope.
    topDownTable :: VTable rep,
    -- | Current depth; recorded in new entries as 'entryDepth'.
    topDownDepth :: Int,
    -- | Lowers desired updates into a statement (used by 'optimiseStms').
    topLowerUpdate :: LowerUpdate rep (ForwardingM rep),
    -- | Optimises representation-specific operations (used by
    -- 'optimiseExp').
    topOnOp :: OnOp rep
  }

-- | The writer output of 'ForwardingM', passed upwards through the
-- traversal.
data BottomUp rep = BottomUp
  { -- | Variables that have been marked as seen.
    bottomUpSeen :: Names,
    -- | In-place updates that we would like to move into an earlier
    -- statement.
    forwardThese :: [DesiredUpdate (LetDec (Aliases rep))]
  }

-- | Combine field by field.
instance Semigroup (BottomUp rep) where
  BottomUp seen1 forward1 <> BottomUp seen2 forward2 =
    BottomUp (seen1 <> seen2) (forward1 <> forward2)

-- | Nothing seen and nothing to forward.
instance Monoid (BottomUp rep) where
  mempty = BottomUp mempty mempty

-- | The monad the traversal runs in.  It combines three pieces:
--
--   * a 'TopDown' environment;
--   * a 'BottomUp' writer output;
--   * a name source as state.
newtype ForwardingM rep a = ForwardingM (RWS (TopDown rep) (BottomUp rep) VNameSource a)
  deriving newtype
    ( Monad,
      Applicative,
      Functor,
      MonadReader (TopDown rep),
      MonadWriter (BottomUp rep),
      MonadState VNameSource
    )

-- | The name source is the state component of 'ForwardingM'.
instance MonadFreshNames (ForwardingM rep) where
  getNameSource = get
  putNameSource = put

-- | The scope is the type information of every entry in the symbol
-- table.
instance Constraints rep => HasScope (Aliases rep) (ForwardingM rep) where
  askScope = M.map entryType <$> asks topDownTable

-- | Run a 'ForwardingM' computation with the given update lowering and
-- operation handler.  The computation starts with an empty symbol
-- table, counter and depth, and threads the given name source through.
-- The accumulated 'BottomUp' output is discarded.
runForwardingM ::
  LowerUpdate rep (ForwardingM rep) ->
  OnOp rep ->
  ForwardingM rep a ->
  VNameSource ->
  (a, VNameSource)
runForwardingM lower onOp (ForwardingM m) src =
  let (x, src', _) = runRWS m emptyTopDown src
   in (x, src')
  where
    emptyTopDown =
      TopDown
        { topDownCounter = 0,
          topDownTable = M.empty,
          topDownDepth = 0,
          topLowerUpdate = lower,
          topOnOp = onOp
        }

-- | Bring the given parameters into scope for the given action.
--
-- Each parameter is recorded with the current counter and depth, no
-- aliases, and as not optimisable.  Its type information is computed
-- by the given function.  The counter is then incremented.  New
-- entries shadow existing ones with the same name ('M.union' is
-- left-biased).
bindingParams ::
  (dec -> NameInfo (Aliases rep)) ->
  [Param dec] ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingParams f params = local $ \top ->
  let n = topDownCounter top
      entry param =
        ( paramName param,
          Entry n mempty (topDownDepth top) False $ f $ paramDec param
        )
      entries = M.fromList $ map entry params
   in -- A record update, rather than a positional pattern, so this
      -- keeps working if 'TopDown' gains fields.
      top
        { topDownCounter = n + 1,
          topDownTable = M.union entries $ topDownTable top
        }

-- | Run an action with the given 'FParam's in scope.
--
-- This is 'bindingParams' with each parameter's decoration wrapped in
-- 'FParamName'.  Like all parameters, these are marked non-optimisable.
bindingFParams ::
  [FParam (Aliases rep)] ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingFParams = bindingParams FParamName

-- | Run an action with every name in the given 'Scope' added to the
-- symbol table.
--
-- All names share the current binding number and depth, and the
-- binding counter is then incremented once.  Aliases are taken from the
-- 'VarAliases' of 'LetName' entries; any other kind of 'NameInfo' is
-- recorded with no aliases.  Every entry is marked non-optimisable.
-- New entries shadow existing ones of the same name (the 'Map'
-- semigroup is a left-biased union).
bindingScope ::
  Scope (Aliases rep) ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingScope scope = local $ \env ->
  let n = topDownCounter env
      d = topDownDepth env
      infoAliases (LetName (aliases, _)) = unAliases aliases
      infoAliases _ = mempty
      entry info = Entry n (infoAliases info) d False info
      entries = M.map entry scope
   in env
        { topDownCounter = n + 1,
          topDownTable = entries <> topDownTable env
        }

-- | Run an action with the names bound by a statement's pattern added to
-- the symbol table.
--
-- All pattern elements share the current binding number and depth, and
-- the binding counter is then incremented once.  Each entry takes its
-- aliases from the element's 'VarAliases' decoration.  Unlike
-- parameters, these entries are marked optimisable.  New entries shadow
-- existing ones of the same name ('M.union' is left-biased).
bindingStm ::
  Stm (Aliases rep) ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingStm (Let pat _ _) = local $ \env ->
  let n = topDownCounter env
      d = topDownDepth env
      entry pe =
        let dec@(aliases, _) = patElemDec pe
         in (patElemName pe, Entry n (unAliases aliases) d True $ LetName dec)
      entries = M.fromList $ map entry $ patElems pat
   in env
        { topDownCounter = n + 1,
          topDownTable = M.union entries (topDownTable env)
        }

-- | Look up the binding number of a variable.
--
-- This is the value of 'topDownCounter' when the variable was bound.
-- Since every @binding*@ function increments the counter, names bound
-- later have larger numbers.  Calls 'error' if the variable is not in
-- the symbol table.
bindingNumber :: VName -> ForwardingM rep Int
bindingNumber name = do
  res <- asks $ fmap entryNumber . M.lookup name . topDownTable
  case res of
    Just n -> pure n
    Nothing ->
      error $
        "bindingNumber: variable "
          ++ prettyString name
          ++ " not found."

-- | Run an action one nesting level deeper.
--
-- Names bound inside the action record the incremented 'topDownDepth',
-- which is what 'isInCurrentBody' compares against.
deepen :: ForwardingM rep a -> ForwardingM rep a
deepen = local $ \env -> env {topDownDepth = topDownDepth env + 1}

-- | Check whether every name in the set was bound strictly before
-- @point@.
--
-- The comparison is on binding numbers and is strict.  Names bound by
-- the same statement or parameter list as @point@ share its number, so
-- they do not count as available before it.  Calls 'error' (via
-- 'bindingNumber') if @point@ or any of the names is unbound.  An empty
-- name set is trivially available.
areAvailableBefore :: Names -> VName -> ForwardingM rep Bool
areAvailableBefore names point = do
  pointN <- bindingNumber point
  nameNs <- mapM bindingNumber $ namesToList names
  pure $ all (< pointN) nameNs

-- | Check whether a variable was bound at the current nesting depth,
-- i.e. without an intervening 'deepen'.
--
-- Calls 'error' if the variable is not in the symbol table.
isInCurrentBody :: VName -> ForwardingM rep Bool
isInCurrentBody name = do
  current <- asks topDownDepth
  res <- asks $ fmap entryDepth . M.lookup name . topDownTable
  case res of
    Just d -> pure $ d == current
    Nothing ->
      error $
        "isInCurrentBody: variable "
          ++ prettyString name
          ++ " not found."

-- | Check whether a variable's symbol-table entry is marked optimisable.
--
-- Only names bound by 'bindingStm' are marked optimisable.  Names bound
-- through 'bindingParams' or 'bindingScope' are not.  Calls 'error' if
-- the variable is not in the symbol table.
isOptimisable :: VName -> ForwardingM rep Bool
isOptimisable name = do
  res <- asks $ fmap entryOptimisable . M.lookup name . topDownTable
  case res of
    Just b -> pure b
    Nothing ->
      error $
        "isOptimisable: variable "
          ++ prettyString name
          ++ " not found."

-- | Record in the 'BottomUp' output that a variable (and everything it
-- aliases) has been seen.
--
-- Unlike the other lookup helpers, an unknown variable is not an error.
-- It is simply recorded with no aliases.
seenVar :: VName -> ForwardingM rep ()
seenVar name = do
  aliases <-
    asks $
      maybe mempty entryAliases
        . M.lookup name
        . topDownTable
  tell $ mempty {bottomUpSeen = oneName name <> aliases}

-- | Run an action and also return the 'BottomUp' output it produced.
--
-- This is exactly 'listen'.  The output is still propagated to the
-- enclosing computation, not suppressed.
tapBottomUp :: ForwardingM rep a -> ForwardingM rep (a, BottomUp rep)
tapBottomUp = listen

-- | Emit a 'DesiredUpdate' proposing that the in-place update producing
-- @dest_nm@ be moved to where @v@ is bound, if the checks below pass.
--
-- In the terms of the module header, @v@ plays the role of @r@, and
-- @src@ and @slice@ those of @y@ and @k@.  This function checks only:
--
--   * condition (2): everything @src@, @slice@ and @cs@ refer to is
--     bound strictly before @v@;
--
--   * condition (3): @v@ is bound in the current body;
--
--   * condition (6): @v@'s entry is marked optimisable;
--
--   * @v@ does not have a primitive type.
--
-- If all of these hold, a single 'DesiredUpdate' is written to
-- 'forwardThese'; otherwise nothing is written.  The checks run in a
-- fixed order, so a missing symbol-table entry is reported by the first
-- lookup that hits it.
maybeForward ::
  Constraints rep =>
  VName ->
  VName ->
  LetDec (Aliases rep) ->
  Certs ->
  VName ->
  Slice SubExp ->
  ForwardingM rep ()
maybeForward v dest_nm dest_dec cs src slice = do
  -- Checks condition (2)
  available <-
    (freeIn src <> freeIn slice <> freeIn cs)
      `areAvailableBefore` v
  -- Check condition (3)
  samebody <- isInCurrentBody v
  -- Check condition (6)
  optimisable <- isOptimisable v
  not_prim <- not . primType <$> lookupType v
  when (available && samebody && optimisable && not_prim) $ do
    let fwd = DesiredUpdate dest_nm dest_dec cs src slice v
    tell mempty {forwardThese = [fwd]}