{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module implements an optimisation that moves in-place
-- updates into/before loops where possible, with the end goal of
-- minimising memory copies.  As an example, consider this program:
--
-- @
--   let r =
--     loop (r1 = r0) = for i < n do
--       let a = r1[i]
--       let r1[i] = a * i
--       in r1
--   ...
--   let x = y with [k] <- r in
--   ...
-- @
--
-- We want to turn this into the following:
--
-- @
--   let x0 = y with [k] <- r0
--   loop (x = x0) = for i < n do
--     let a = x[k,i]
--     let x[k,i] = a * i
--     in x
--   let r = x[k] in
--   ...
-- @
--
-- The intent is that we are also going to optimise the new data
-- movement (in the @x0@-binding), possibly by changing how @r0@ is
-- defined.  For the above transformation to be valid, a number of
-- conditions must be fulfilled:
--
--    (1) @r@ must not be consumed after the original in-place update.
--
--    (2) @k@ and @y@ must be available at the beginning of the loop.
--
--    (3) @x@ must be visible whenever @r@ is visible.  (This means
--    that both @x@ and @r@ must be bound in the same t'Body'.)
--
--    (4) If @x@ is consumed at a point after the loop, @r@ must not
--    be used after that point.
--
--    (5) The size of @r1@ is invariant inside the loop.
--
--    (6) The value @r@ must come from something that we can actually
--    optimise (e.g. not a function parameter).
--
--    (7) @y@ (or its aliases) may not be used inside the body of the
--    loop.
--
--    (8) The result of the loop may not alias the merge parameter
--    @r1@.
--
-- FIXME: the implementation is not finished yet.  Specifically, not
-- all of the above conditions are checked.
module Futhark.Optimise.InPlaceLowering
  ( inPlaceLoweringKernels,
    inPlaceLoweringSeq,
  )
where

import Control.Monad.RWS
import qualified Data.Map.Strict as M
import Futhark.Analysis.Alias
import Futhark.Binder
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.IR.Seq (Seq)
import Futhark.Optimise.InPlaceLowering.LowerIntoStm
import Futhark.Pass

-- | Apply the in-place lowering optimisation to a 'Kernels' program.
--
-- Kernel operations are descended into by 'onKernelOp', and candidate
-- statements are lowered with 'lowerUpdateKernels'.
inPlaceLoweringKernels :: Pass Kernels Kernels
inPlaceLoweringKernels = inPlaceLowering onKernelOp lowerUpdateKernels

-- | Apply the in-place lowering optimisation to a 'Seq' program.
--
-- Lore-specific operations are left untouched (the 'OnOp' handler is
-- just 'pure'), and candidate statements are lowered with the generic
-- 'lowerUpdate'.
inPlaceLoweringSeq :: Pass Seq Seq
inPlaceLoweringSeq = inPlaceLowering pure lowerUpdate

-- | Build an in-place lowering pass.
--
-- @onOp@ is used to descend into lore-specific operations. @lower@
-- decides whether, and how, a statement can absorb the in-place
-- updates that want to be forwarded into it.
--
-- The pass works in three stages:
--
-- 1. The program is annotated with alias information ('aliasAnalysis').
-- 2. The top-level constants and every function are optimised in
--    separate 'ForwardingM' runs, threading the name source.
-- 3. The alias annotations are removed again ('removeProgAliases').
inPlaceLowering ::
  Constraints lore =>
  OnOp lore ->
  LowerUpdate lore (ForwardingM lore) ->
  Pass lore lore
inPlaceLowering onOp lower =
  Pass "In-place lowering" "Lower in-place updates into loops" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts optimiseConsts optimiseFunDef
      . aliasAnalysis
  where
    -- The constants have no enclosing body result. The continuation
    -- run after all of them have been visited is therefore trivial.
    optimiseConsts stms =
      modifyNameSource $
        runForwardingM lower onOp $
          stmsFromList <$> optimiseStms (stmsToList stms) (pure ())

    -- A function body is optimised with the constants and the function
    -- parameters in scope.
    optimiseFunDef consts fundec =
      modifyNameSource $
        runForwardingM lower onOp $
          descend (stmsToList consts) $
            bindingFParams (funDefParams fundec) $ do
              body <- optimiseBody $ funDefBody fundec
              pure fundec {funDefBody = body}

    -- Bring a list of statements into scope, first statement outermost.
    descend stms m = foldr bindingStm m stms

-- | Constraints a lore must satisfy to be processed by this pass:
-- it must be 'Bindable', and its operations must support alias
-- annotation ('CanBeAliased').
type Constraints lore = (Bindable lore, CanBeAliased (Op lore))

-- | Optimise the statements of a body.
--
-- The statements are optimised inside 'deepen'. Once all of them are
-- in scope, every variable in the body result is recorded with
-- 'seenVar'. The body's alias annotation and result are kept as-is.
optimiseBody ::
  Constraints lore =>
  Body (Aliases lore) ->
  ForwardingM lore (Body (Aliases lore))
optimiseBody (Body als bnds res) = do
  bnds' <- deepen $ optimiseStms (stmsToList bnds) (mapM_ seen res)
  pure $ Body als (stmsFromList bnds') res
  where
    -- Constants have no name, so there is nothing to record.
    seen Constant {} = pure ()
    seen (Var v) = seenVar v

-- | Optimise a list of statements, followed by the continuation @m@.
--
-- The continuation typically marks the enclosing result as seen.
--
-- For each statement, the tail of the list (and finally @m@) is
-- processed first, with that statement in scope. The tail may report
-- desired in-place updates. Those whose 'updateValue' is bound by the
-- current statement are offered to the lore's 'LowerUpdate':
--
--   * If it produces a lowering, the resulting statements replace the
--     current one and are optimised in turn. The statements that bound
--     the forwarded updates are removed from the tail.
--
--   * Otherwise, or if there was nothing to forward, the current
--     statement is kept and checked for being a forwardable in-place
--     update itself.
optimiseStms ::
  Constraints lore =>
  [Stm (Aliases lore)] ->
  ForwardingM lore () ->
  ForwardingM lore [Stm (Aliases lore)]
optimiseStms [] m = [] <$ m
optimiseStms (bnd : bnds) m = do
  (bnds', bup) <- tapBottomUp $ bindingStm bnd $ optimiseStms bnds m
  bnd' <- optimiseInStm bnd
  -- Keep the (optimised) statement, and check whether it is itself an
  -- update that could be forwarded further up.
  let keepStm = do
        checkIfForwardableUpdate bnd'
        pure $ bnd' : bnds'
  case filter ((`elem` boundHere) . updateValue) $ forwardThese bup of
    [] -> keepStm
    updates -> do
      lower <- asks topLowerUpdate
      scope <- askScope

      -- If we forward any updates, we need to remove them from bnds'.
      let updated_names = map updateName updates
          notUpdated =
            not . any (`elem` updated_names) . patternNames . stmPattern

      -- Condition (5) and (7) are assumed to be checked by
      -- lowerUpdate.
      case lower scope bnd' updates of
        Just lowering -> do
          new_bnds <- lowering
          -- The lowered statements are optimised as if followed by the
          -- original tail: its seen names are replayed, but none of
          -- its forwarding requests.
          new_bnds' <-
            optimiseStms new_bnds $
              tell bup {forwardThese = []}
          pure $ new_bnds' ++ filter notUpdated bnds'
        Nothing -> keepStm
  where
    boundHere = patternNames $ stmPattern bnd

    -- A candidate for forwarding must meet all of these conditions:
    --   * it binds exactly one value and no context;
    --   * it is an in-place 'Update';
    --   * the value written by the update is a variable.
    checkIfForwardableUpdate (Let pat (StmAux cs _ _) e)
      | Pattern [] [PatElem v dec] <- pat,
        BasicOp (Update src slice (Var ve)) <- e =
          maybeForward ve v dec cs src slice
    checkIfForwardableUpdate _ = pure ()

-- | Optimise the expression of a statement, keeping its pattern and
-- auxiliary information unchanged.
optimiseInStm :: Constraints lore => Stm (Aliases lore) -> ForwardingM lore (Stm (Aliases lore))
optimiseInStm (Let pat dec e) = Let pat dec <$> optimiseExp e

-- | Optimise the bodies nested inside an expression.
--
-- The handling depends on the kind of expression:
--
--   * Loops: the loop form and all merge parameters (context and
--     value) are put in scope before the loop body is optimised.
--   * Lore-specific operations: handed to the 'OnOp' handler stored
--     in the environment ('topOnOp').
--   * Anything else: its nested bodies are optimised through
--     'mapExpM'.
optimiseExp :: Constraints lore => Exp (Aliases lore) -> ForwardingM lore (Exp (Aliases lore))
optimiseExp (DoLoop ctx val form body) =
  bindingScope (scopeOf form) $
    bindingFParams (map fst $ ctx ++ val) $
      DoLoop ctx val form <$> optimiseBody body
optimiseExp (Op op) = do
  onOp <- asks topOnOp
  Op <$> onOp op
optimiseExp e = mapExpM optimise e
  where
    -- Only bodies are rewritten; the scope passed to 'mapOnBody' is
    -- ignored.
    optimise = identityMapper {mapOnBody = const optimiseBody}

-- | The 'OnOp' handler for 'Kernels'.
--
-- For a 'SegOp', the segment space is put in scope and the statements
-- of its kernel body are optimised inside 'deepen'. Every variable
-- free in the kernel results is then recorded with 'seenVar'. All
-- other host operations are returned unchanged.
onKernelOp :: OnOp Kernels
onKernelOp (SegOp op) =
  bindingScope (scopeOfSegSpace (segSpace op)) $ do
    let mapper = identitySegOpMapper {mapOnSegOpBody = onKernelBody}
        onKernelBody kbody = do
          stms <-
            deepen $
              optimiseStms (stmsToList (kernelBodyStms kbody)) $
                mapM_ seenVar $ namesToList $ freeIn $ kernelBodyResult kbody
          pure kbody {kernelBodyStms = stmsFromList stms}
    SegOp <$> mapSegOpM mapper op
onKernelOp op = pure op

-- | Information recorded about a bound variable.
--
-- NOTE(review): the code that builds and consults entries is not in
-- this part of the module, so the field notes below are inferred from
-- names and types. Confirm them against that code.
data Entry lore = Entry
  { -- | Presumably the position of the binding in traversal order.
    entryNumber :: Int,
    -- | Presumably the names this variable aliases.
    entryAliases :: Names,
    -- | Presumably the nesting depth at which it was bound.
    entryDepth :: Int,
    -- | Presumably whether the variable may be the target of lowering
    -- (cf. condition (6) in the module header).
    entryOptimisable :: Bool,
    -- | Type information about the binding.
    entryType :: NameInfo (Aliases lore)
  }

type VTable lore = M.Map VName (Entry lore)

-- | Transformation applied to lore-specific 'Op' nodes.
type OnOp lore = Op (Aliases lore) -> ForwardingM lore (Op (Aliases lore))

-- | The read-only environment of 'ForwardingM'.
data TopDown lore = TopDown
  { -- | Number handed to the next group of bindings.
    topDownCounter :: Int,
    -- | Variables currently in scope.
    topDownTable :: VTable lore,
    -- | Current body-nesting depth, incremented by 'deepen'.
    topDownDepth :: Int,
    -- | The update-lowering function supplied to 'runForwardingM'.
    topLowerUpdate :: LowerUpdate lore (ForwardingM lore),
    -- | The 'Op' handler supplied to 'runForwardingM'.
    topOnOp :: OnOp lore
  }

-- | Information accumulated in the writer component of 'ForwardingM'.
data BottomUp lore = BottomUp
  { -- | Variables observed so far, together with their aliases
    -- (see 'seenVar').
    bottomUpSeen :: Names,
    -- | Updates that passed the checks in 'maybeForward'.
    forwardThese :: [DesiredUpdate (LetDec (Aliases lore))]
  }

-- | Field-wise combination: seen sets are unioned and forwarding
-- candidates are concatenated, left operand first.
instance Semigroup (BottomUp lore) where
  BottomUp seenL fwdL <> BottomUp seenR fwdR =
    BottomUp
      { bottomUpSeen = seenL <> seenR,
        forwardThese = fwdL <> fwdR
      }

-- | Nothing seen and nothing to forward.
instance Monoid (BottomUp lore) where
  mempty = BottomUp {bottomUpSeen = mempty, forwardThese = []}

-- | The monad of the forwarding pass.  It reads a 'TopDown'
-- environment, accumulates a 'BottomUp' result, and threads a
-- 'VNameSource' as state (used by the 'MonadFreshNames' instance).
newtype ForwardingM lore a
  = ForwardingM (RWS (TopDown lore) (BottomUp lore) VNameSource a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState VNameSource,
      MonadReader (TopDown lore),
      MonadWriter (BottomUp lore)
    )

-- | The name source is simply the state component of the monad.
instance MonadFreshNames (ForwardingM lore) where
  getNameSource = get
  putNameSource new_src = put new_src

-- | The visible scope is the symbol table with each 'Entry' reduced to
-- its 'entryType'.
instance Constraints lore => HasScope (Aliases lore) (ForwardingM lore) where
  askScope = asks (M.map entryType . topDownTable)

-- | Run a 'ForwardingM' computation starting from an empty symbol table
-- at depth zero, with the given update-lowering function and 'Op'
-- handler installed in the environment.  Returns the result and the
-- final name source; the accumulated 'BottomUp' output is discarded.
runForwardingM ::
  LowerUpdate lore (ForwardingM lore) ->
  OnOp lore ->
  ForwardingM lore a ->
  VNameSource ->
  (a, VNameSource)
runForwardingM lower_update on_op (ForwardingM m) =
  dropOutput . runRWS m initial
  where
    -- Irrefutable so that, as before, the triple is only taken apart
    -- when one of its components is demanded.
    dropOutput ~(x, src', _) = (x, src')
    initial =
      TopDown
        { topDownCounter = 0,
          topDownTable = M.empty,
          topDownDepth = 0,
          topLowerUpdate = lower_update,
          topOnOp = on_op
        }

-- | Run an action with the given parameters added to the symbol table.
-- All parameters share one new binding number, carry no aliases, and
-- are marked not optimisable.  The first argument turns a parameter's
-- decoration into its scope information.  New entries shadow existing
-- ones with the same name.
bindingParams ::
  (dec -> NameInfo (Aliases lore)) ->
  [Param dec] ->
  ForwardingM lore a ->
  ForwardingM lore a
bindingParams f params = local extend
  where
    extend (TopDown n vtable d x y) =
      let new_entries = M.fromList $ map (toEntry n d) params
       in TopDown (n + 1) (new_entries `M.union` vtable) d x y
    toEntry n d p = (paramName p, Entry n mempty d False (f (paramDec p)))

-- | 'bindingParams' for function parameters, which are entered into the
-- scope as 'FParamName'.
bindingFParams ::
  [FParam (Aliases lore)] ->
  ForwardingM lore a ->
  ForwardingM lore a
bindingFParams fparams = bindingParams FParamName fparams

-- | Run an action with every name in the given scope added to the
-- symbol table under a single new binding number.  Only 'LetName'
-- entries contribute aliases, and nothing bound here is optimisable.
-- New entries shadow existing ones with the same name.
bindingScope ::
  Scope (Aliases lore) ->
  ForwardingM lore a ->
  ForwardingM lore a
bindingScope scope = local $ \(TopDown n vtable d x y) ->
  let toEntry info =
        let als = case info of
              LetName (aliases, _) -> unAliases aliases
              _ -> mempty
         in Entry n als d False info
   in TopDown (n + 1) (M.map toEntry scope `M.union` vtable) d x y

-- | Run an action with the names bound by a statement's pattern added
-- to the symbol table under a single new binding number.  Aliases come
-- from each pattern element's decoration.  Unlike entries from
-- 'bindingParams' and 'bindingScope', these are marked optimisable.
bindingStm ::
  Stm (Aliases lore) ->
  ForwardingM lore a ->
  ForwardingM lore a
bindingStm (Let pat _ _) = local $ \(TopDown n vtable d x y) ->
  let toEntry pe =
        let dec@(als, _) = patElemDec pe
         in (patElemName pe, Entry n (unAliases als) d True (LetName dec))
      new_entries = M.fromList [toEntry pe | pe <- patternElements pat]
   in TopDown (n + 1) (new_entries `M.union` vtable) d x y

-- | The binding number of a variable.  Raises an internal error if the
-- variable is not in the symbol table.
bindingNumber :: VName -> ForwardingM lore Int
bindingNumber name =
  maybe notFound (return . entryNumber)
    =<< asks (M.lookup name . topDownTable)
  where
    notFound =
      error $ "bindingNumber: variable " ++ pretty name ++ " not found."

-- | Run an action one body-nesting level deeper.
deepen :: ForwardingM lore a -> ForwardingM lore a
deepen = local nest
  where
    nest env = env {topDownDepth = topDownDepth env + 1}

-- | Were all of the given names bound in a strictly earlier binding
-- group than @point@?  Every name, including @point@, must be in the
-- symbol table; otherwise 'bindingNumber' raises an error.
areAvailableBefore :: Names -> VName -> ForwardingM lore Bool
areAvailableBefore names point = do
  pointN <- bindingNumber point
  all (< pointN) <$> traverse bindingNumber (namesToList names)

-- | Was the variable bound at the current body-nesting depth?  Raises
-- an internal error if the variable is not in the symbol table.
isInCurrentBody :: VName -> ForwardingM lore Bool
isInCurrentBody name = do
  depth <- asks topDownDepth
  found <- asks $ M.lookup name . topDownTable
  case found of
    Nothing ->
      error $ "isInCurrentBody: variable " ++ pretty name ++ " not found."
    Just entry ->
      return $ entryDepth entry == depth

-- | The 'entryOptimisable' flag of a variable.  Raises an internal error
-- if the variable is not in the symbol table.
isOptimisable :: VName -> ForwardingM lore Bool
isOptimisable name =
  maybe notFound (return . entryOptimisable)
    =<< asks (M.lookup name . topDownTable)
  where
    notFound =
      error $ "isOptimisable: variable " ++ pretty name ++ " not found."

-- | Record that a variable has been observed, together with the aliases
-- stored for it in the symbol table.  A variable not in the table
-- contributes only its own name.
seenVar :: VName -> ForwardingM lore ()
seenVar name = do
  vtable <- asks topDownTable
  let aliases = maybe mempty entryAliases $ M.lookup name vtable
  tell
    BottomUp
      { bottomUpSeen = oneName name <> aliases,
        forwardThese = []
      }

-- | Run an action and additionally return the 'BottomUp' information it
-- produced.  The information is still propagated to the enclosing
-- computation.  This is exactly 'listen'; the previous
-- @do (x, b) <- listen m; return (x, b)@ only rebuilt the same pair.
tapBottomUp :: ForwardingM lore a -> ForwardingM lore (a, BottomUp lore)
tapBottomUp = listen

-- | Consider forwarding an in-place update of @src@ at @slice@ with the
-- value @v@, whose result is bound as @dest_nm@.  In the module
-- header's notation @x = y with [k] <- r@: @dest_nm@ is @x@, @src@ is
-- @y@, @slice@ is @k@ and @v@ is @r@.  If the checks below pass, a
-- 'DesiredUpdate' is recorded in 'forwardThese'; otherwise nothing is
-- recorded.
maybeForward ::
  Constraints lore =>
  VName ->
  VName ->
  LetDec (Aliases lore) ->
  Certificates ->
  VName ->
  Slice SubExp ->
  ForwardingM lore ()
maybeForward v dest_nm dest_dec cs src slice = do
  -- Condition (2): everything the update refers to must already be
  -- bound before v.
  let needed = freeIn src <> freeIn slice <> freeIn cs
  available <- needed `areAvailableBefore` v
  -- Condition (3): v must be bound in the current body.
  samebody <- isInCurrentBody v
  -- Condition (6): v must be something we can actually optimise.
  optimisable <- isOptimisable v
  -- Only non-primitive (array) values are forwarded.
  v_type <- lookupType v
  let ok = and [available, samebody, optimisable, not (primType v_type)]
  when ok $
    tell mempty {forwardThese = [DesiredUpdate dest_nm dest_dec cs src slice v]}