{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module implements an optimisation that moves in-place
-- updates into/before loops where possible, with the end goal of
-- minimising memory copies.  As an example, consider this program:
--
-- @
--   let r =
--     loop (r1 = r0) = for i < n do
--       let a = r1[i]
--       let r1[i] = a * i
--       in r1
--   ...
--   let x = y with [k] <- r in
--   ...
-- @
--
-- We want to turn this into the following:
--
-- @
--   let x0 = y with [k] <- r0
--   loop (x = x0) = for i < n do
--     let a = x[k,i]
--     let x[k,i] = a * i
--     in x
--   let r = x[k] in
--   ...
-- @
--
-- The intent is that we are also going to optimise the new data
-- movement (in the @x0@-binding), possibly by changing how @r0@ is
-- defined.  For the above transformation to be valid, a number of
-- conditions must be fulfilled:
--
--    (1) @r@ must not be consumed after the original in-place update.
--
--    (2) @k@ and @y@ must be available at the beginning of the loop.
--
--    (3) @x@ must be visible whenever @r@ is visible.  (This means
--    that both @x@ and @r@ must be bound in the same t'Body'.)
--
--    (4) If @x@ is consumed at a point after the loop, @r@ must not
--    be used after that point.
--
--    (5) The size of @r1@ is invariant inside the loop.
--
--    (6) The value @r@ must come from something that we can actually
--    optimise (e.g. not a function parameter).
--
--    (7) @y@ (or its aliases) may not be used inside the body of the
--    loop.
--
--    (8) The result of the loop may not alias the merge parameter
--    @r1@.
--
--    (9) @y@ or its aliases may not be used after the loop.
--
-- FIXME: the implementation is not finished yet.  Specifically, not
-- all of the above conditions are checked.
module Futhark.Optimise.InPlaceLowering
  ( inPlaceLoweringGPU,
    inPlaceLoweringSeq,
    inPlaceLoweringMC,
  )
where

import Control.Monad.RWS
import qualified Data.Map.Strict as M
import Data.Ord (comparing)
import Futhark.Analysis.Alias
import Futhark.Builder
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.IR.Seq (Seq)
import Futhark.Optimise.InPlaceLowering.LowerIntoStm
import Futhark.Pass
import Futhark.Util (nubByOrd)

-- | Apply the in-place lowering optimisation to the given program.
--
-- GPU operations are traversed with 'onKernelOp', and desired updates
-- are lowered with the GPU-specific 'lowerUpdateGPU'.
inPlaceLoweringGPU :: Pass GPU GPU
inPlaceLoweringGPU = inPlaceLowering onKernelOp lowerUpdateGPU

-- | Apply the in-place lowering optimisation to the given program.
--
-- Operations are passed through unchanged ('pure'), and desired
-- updates are lowered with the generic 'lowerUpdate'.
inPlaceLoweringSeq :: Pass Seq Seq
inPlaceLoweringSeq = inPlaceLowering pure lowerUpdate

-- | Apply the in-place lowering optimisation to the given program.
--
-- Multicore operations are traversed with 'onMCOp', and desired
-- updates are lowered with the generic 'lowerUpdate'.
inPlaceLoweringMC :: Pass MC MC
inPlaceLoweringMC = inPlaceLowering onMCOp lowerUpdate

-- | Apply the in-place lowering optimisation to the given program.
--
-- The program is first annotated with aliasing information
-- ('aliasAnalysis').  The top-level constants are optimised on their
-- own, and every function body is optimised with the constants and the
-- function's parameters in scope.  Finally the aliasing annotations are
-- stripped again ('removeProgAliases').
--
-- The first argument is how to descend into the representation's
-- operations.  The second is the strategy used to lower desired
-- in-place updates into a statement.
inPlaceLowering ::
  Constraints rep =>
  OnOp rep ->
  LowerUpdate rep (ForwardingM rep) ->
  Pass rep rep
inPlaceLowering onOp lower =
  Pass "In-place lowering" "Lower in-place updates into loops" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts optimiseConsts optimiseFunDef
      . aliasAnalysis
  where
    -- The helpers below close over 'onOp' and 'lower' from the equation
    -- above.  A type signature mentioning @rep@ would require
    -- ScopedTypeVariables, so they are left unannotated.

    -- Optimise the top-level constants.  The trailing action is a
    -- no-op: nothing is marked as used after the last constant.
    optimiseConsts stms =
      modifyNameSource $
        runForwardingM lower onOp $
          stmsFromList <$> optimiseStms (stmsToList stms) (pure ())

    -- Optimise a function body with the (already optimised) constants
    -- and the function's parameters in scope.
    optimiseFunDef consts fundec =
      modifyNameSource $
        runForwardingM lower onOp $
          descend (stmsToList consts) $
            bindingFParams (funDefParams fundec) $ do
              body <- optimiseBody $ funDefBody fundec
              pure $ fundec {funDefBody = body}

    -- Run @m@ with the bindings of every statement in scope; the first
    -- statement ends up outermost.
    descend stms m = foldr bindingStm m stms

-- | What a representation must provide for in-place lowering to be
-- applied to it.
type Constraints rep =
  ( Buildable rep,
    CanBeAliased (Op rep)
  )

-- | Optimise the statements of a body.
--
-- Once all of the body's statements are in scope, every variable in the
-- body's result is marked as seen, because the result uses it after the
-- last statement.  The body's aliasing annotation and result are kept
-- as they are.
optimiseBody ::
  Constraints rep =>
  Body (Aliases rep) ->
  ForwardingM rep (Body (Aliases rep))
optimiseBody (Body als stms res) = do
  stms' <-
    deepen $
      optimiseStms (stmsToList stms) $
        mapM_ (seen . resSubExp) res
  pure $ Body als (stmsFromList stms') res
  where
    -- Constants do not refer to any variable, so there is nothing to mark.
    seen Constant {} = pure ()
    seen (Var v) = seenVar v

-- | Optimise a sequence of statements.
--
-- The statements are processed back to front.  The tail is optimised
-- first, with the current statement's bindings in scope.  The bottom-up
-- information it produces ('BottomUp') then decides whether pending
-- in-place updates can be lowered into the current statement.
--
-- The given action runs after the last statement has been bound.
-- Callers use it to mark what is used after the sequence, e.g. a
-- body's result.
optimiseStms ::
  Constraints rep =>
  [Stm (Aliases rep)] ->
  ForwardingM rep () ->
  ForwardingM rep [Stm (Aliases rep)]
optimiseStms [] m = m >> pure []
optimiseStms (stm : stms) m = do
  (stms', bup) <- tapBottomUp $ bindingStm stm $ optimiseStms stms m
  stm' <- optimiseInStm stm
  -- Updates requested further down that could be lowered into this
  -- statement.  Their value must be bound here, and their source must
  -- not be used after this point (condition (9)).
  --
  -- XXX: unfortunate that we cannot handle duplicate update values.
  -- Would be good to improve this.  See inplacelowering6.fut.
  let candidates =
        nubByOrd (comparing updateValue)
          . filter ((`notNameIn` bottomUpSeen bup) . updateSource) -- (9)
          . filter ((`elem` boundHere) . updateValue)
          $ forwardThese bup
      -- Keep the statement as it is.  It may itself be an update worth
      -- forwarding further up.
      keep = do
        checkIfForwardableUpdate stm'
        pure $ stm' : stms'
  case candidates of
    [] -> keep
    updates -> do
      lower <- asks topLowerUpdate
      scope <- askScope

      -- If we forward any updates, we need to remove them from stms'.
      let updated_names = map updateName updates
          notUpdated = not . any (`elem` updated_names) . patNames . stmPat

      -- Condition (5) and (7) are assumed to be checked by
      -- lowerUpdate.
      case lower scope stm' updates of
        Just lowering -> do
          new_stms <- lowering
          -- Optimise the replacement statements in turn.  The tail's
          -- bottom-up information is re-emitted after them, with its
          -- list of updates to forward cleared.
          new_stms' <- optimiseStms new_stms $ tell bup {forwardThese = []}
          pure $ new_stms' ++ filter notUpdated stms'
        Nothing -> keep
  where
    boundHere = patNames $ stmPat stm

    -- A single-element pattern bound to an unsafe in-place update whose
    -- value is a variable is handed to 'maybeForward'.  Any other
    -- statement just marks every free variable of its expression as
    -- seen.
    checkIfForwardableUpdate (Let pat (StmAux cs _ _) e)
      | Pat [PatElem v dec] <- pat,
        BasicOp (Update Unsafe src slice (Var ve)) <- e =
          maybeForward ve v dec cs src slice
    checkIfForwardableUpdate other =
      mapM_ seenVar $ namesToList $ freeIn $ stmExp other

-- | Optimise the expression of a statement.  The pattern and auxiliary
-- information are kept unchanged.
optimiseInStm :: Constraints rep => Stm (Aliases rep) -> ForwardingM rep (Stm (Aliases rep))
optimiseInStm (Let pat aux e) = Let pat aux <$> optimiseExp e

-- | Optimise the bodies nested inside an expression.
--
-- * A loop body is optimised with the loop form and the merge
--   parameters in scope.
--
-- * An operation is handed to the representation-specific handler
--   found in the environment ('topOnOp').
--
-- * Any other expression has its nested bodies optimised through
--   'mapExpM'.
optimiseExp :: Constraints rep => Exp (Aliases rep) -> ForwardingM rep (Exp (Aliases rep))
optimiseExp (DoLoop merge form body) =
  bindingScope (scopeOf form) . bindingFParams (map fst merge) $
    DoLoop merge form <$> optimiseBody body
optimiseExp (Op op) = do
  onOp <- asks topOnOp
  Op <$> onOp op
optimiseExp e = mapExpM optimise e
  where
    -- Only bodies are rewritten; every other component is left as is.
    optimise =
      identityMapper
        { mapOnBody = const optimiseBody
        }

-- | Apply the in-place lowering pass to the kernel bodies of a 'SegOp'.
--
-- The names bound by the operation's 'SegSpace' are brought into scope
-- first.  Each kernel body is then processed one nesting level deeper
-- (via 'deepen'), with 'optimiseStms' rewriting its statements.  The
-- trailing action handed to 'optimiseStms' marks every name free in the
-- kernel results as seen.
onSegOp ::
  (Buildable rep, CanBeAliased (Op rep)) =>
  SegOp lvl (Aliases rep) ->
  ForwardingM rep (SegOp lvl (Aliases rep))
onSegOp op =
  bindingScope (scopeOfSegSpace (segSpace op)) $
    mapSegOpM (identitySegOpMapper {mapOnSegOpBody = onKernelBody}) op
  where
    onKernelBody kbody = do
      let resultNames = namesToList $ freeIn $ kernelBodyResult kbody
          bodyStms = stmsToList $ kernelBodyStms kbody
      stms' <- deepen $ optimiseStms bodyStms $ mapM_ seenVar resultNames
      pure kbody {kernelBodyStms = stmsFromList stms'}

-- | 'OnOp' handler for the multicore representation.
--
-- For a 'ParOp', both the optional nested 'SegOp' and the main 'SegOp'
-- are processed with 'onSegOp', the optional one first.  Every other
-- operation is returned unchanged.
onMCOp :: OnOp MC
onMCOp (ParOp par_op op) = do
  par_op' <- traverse onSegOp par_op
  op' <- onSegOp op
  pure $ ParOp par_op' op'
onMCOp op = pure op

-- | 'OnOp' handler for the GPU representation.
--
-- Only 'SegOp' host operations are rewritten (via 'onSegOp'); all other
-- host operations pass through untouched.
onKernelOp :: OnOp GPU
onKernelOp op = case op of
  SegOp segop -> fmap SegOp (onSegOp segop)
  _ -> pure op

-- | What the symbol table records about a single bound name.
data Entry rep = Entry
  { -- | Value of 'topDownCounter' at the point the name was bound; used to
    -- compare binding order.
    entryNumber :: Int,
    -- | Names the bound variable aliases.  Parameters are recorded with no
    -- aliases.
    entryAliases :: Names,
    -- | Value of 'topDownDepth' at the point the name was bound.
    entryDepth :: Int,
    -- | Whether the name may be targeted by the optimisation.  Only
    -- let-bound names ('bindingStm') are marked 'True'.
    entryOptimisable :: Bool,
    -- | The name's scope information.
    entryType :: NameInfo (Aliases rep)
  }

-- | Symbol table mapping each bound name to its 'Entry'.
type VTable rep = M.Map VName (Entry rep)

-- | Handler used to rewrite 'Op' expressions.  It is stored in 'topOnOp'
-- and looked up by 'optimiseExp' when it meets an 'Op'.
type OnOp rep = Op (Aliases rep) -> ForwardingM rep (Op (Aliases rep))

-- | Read-only environment threaded downwards through the traversal.
data TopDown rep = TopDown
  { -- | Binding counter; incremented once per group of names bound
    -- together.
    topDownCounter :: Int,
    -- | Symbol table for everything currently in scope.
    topDownTable :: VTable rep,
    -- | Current nesting depth; increased by 'deepen'.
    topDownDepth :: Int,
    -- | The update-lowering function passed to 'runForwardingM'.
    topLowerUpdate :: LowerUpdate rep (ForwardingM rep),
    -- | The 'Op' handler passed to 'runForwardingM'.
    topOnOp :: OnOp rep
  }

-- | Information accumulated upwards (via the writer) during the traversal.
data BottomUp rep = BottomUp
  { -- | Names observed so far, together with their aliases (see 'seenVar').
    bottomUpSeen :: Names,
    -- | Desired in-place updates collected so far.  Presumably these are
    -- the candidates for lowering; they are populated outside this
    -- excerpt, so confirm against 'maybeForward'.
    forwardThese :: [DesiredUpdate (LetDec (Aliases rep))]
  }

-- | Merge two pieces of bottom-up information.  Seen-name sets are
-- combined, and desired updates are concatenated left before right.
instance Semigroup (BottomUp rep) where
  BottomUp s1 f1 <> BottomUp s2 f2 =
    BottomUp {bottomUpSeen = s1 <> s2, forwardThese = f1 <> f2}

-- | The empty result: nothing seen and no desired updates.
instance Monoid (BottomUp rep) where
  mempty = BottomUp {bottomUpSeen = mempty, forwardThese = []}

-- | The monad the pass runs in.  It reads a 'TopDown' environment, writes
-- 'BottomUp' information and threads a 'VNameSource' as state.  All
-- instances are inherited from the underlying 'RWS'.
newtype ForwardingM rep a = ForwardingM (RWS (TopDown rep) (BottomUp rep) VNameSource a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (TopDown rep),
      MonadWriter (BottomUp rep),
      MonadState VNameSource
    )

-- | Fresh names come from the 'VNameSource' held in the monad's state.
instance MonadFreshNames (ForwardingM rep) where
  getNameSource = ForwardingM get
  putNameSource src = ForwardingM (put src)

-- | The scope is the symbol table, projected down to each entry's
-- 'NameInfo'.
instance Constraints rep => HasScope (Aliases rep) (ForwardingM rep) where
  askScope = asks (M.map entryType . topDownTable)

-- | Run a 'ForwardingM' computation from an empty environment.
--
-- The counter and depth start at zero and the symbol table starts empty.
-- The given lowering function and 'Op' handler are installed in the
-- environment.  The 'BottomUp' output is discarded; only the result and
-- the updated name source are returned.
runForwardingM ::
  LowerUpdate rep (ForwardingM rep) ->
  OnOp rep ->
  ForwardingM rep a ->
  VNameSource ->
  (a, VNameSource)
runForwardingM lower onOp (ForwardingM m) src = (result, src')
  where
    (result, src', _) = runRWS m initialEnv src
    initialEnv =
      TopDown
        { topDownCounter = 0,
          topDownTable = M.empty,
          topDownDepth = 0,
          topLowerUpdate = lower,
          topOnOp = onOp
        }

-- | Run an action with the given parameters added to the symbol table.
--
-- All parameters share the current binding number and depth.  They have
-- no recorded aliases and are never optimisable.  Their scope information
-- is computed from each parameter's decoration with @f@.  The new entries
-- shadow existing ones ('M.union' is left-biased), and the binding
-- counter is then incremented once.
bindingParams ::
  (dec -> NameInfo (Aliases rep)) ->
  [Param dec] ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingParams f params = local extend
  where
    extend env =
      env
        { topDownCounter = topDownCounter env + 1,
          topDownTable =
            M.fromList (map (paramEntry env) params) `M.union` topDownTable env
        }
    paramEntry env p =
      ( paramName p,
        Entry
          { entryNumber = topDownCounter env,
            entryAliases = mempty,
            entryDepth = topDownDepth env,
            entryOptimisable = False,
            entryType = f (paramDec p)
          }
      )

-- | Bind function/loop parameters for the duration of an action.  Each
-- parameter's scope information is 'FParamName' of its decoration.
bindingFParams ::
  [FParam (Aliases rep)] ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingFParams params = bindingParams FParamName params

-- | Run an action with every name in the given scope added to the symbol
-- table.
--
-- All names share the current binding number and depth and are not
-- optimisable.  Aliases are taken from 'LetName' entries; other kinds of
-- name get no aliases.  The new entries take precedence over existing
-- ones, and the counter is incremented once.
bindingScope ::
  Scope (Aliases rep) ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingScope scope = local $ \env ->
  let n = topDownCounter env
      d = topDownDepth env
      aliasesOf (LetName (als, _)) = unAliases als
      aliasesOf _ = mempty
      toEntry info = Entry n (aliasesOf info) d False info
   in env
        { topDownCounter = n + 1,
          topDownTable = M.map toEntry scope <> topDownTable env
        }

-- | Run an action with the names bound by a statement's pattern in scope.
--
-- Each pattern element is recorded with its aliases and marked as
-- optimisable.  Its scope information is 'LetName' of its decoration.  All
-- elements share the current binding number and depth.  The new entries
-- shadow existing ones, and the counter is incremented once.
bindingStm ::
  Stm (Aliases rep) ->
  ForwardingM rep a ->
  ForwardingM rep a
bindingStm (Let pat _ _) = local $ \env ->
  let n = topDownCounter env
      d = topDownDepth env
      bind pe =
        let dec@(als, _) = patElemDec pe
         in (patElemName pe, Entry n (unAliases als) d True (LetName dec))
   in env
        { topDownCounter = n + 1,
          topDownTable =
            M.fromList (map bind (patElems pat)) `M.union` topDownTable env
        }

-- | Look up the binding number recorded for a name.  Calls 'error' if the
-- name is not in the symbol table.
bindingNumber :: VName -> ForwardingM rep Int
bindingNumber name = do
  table <- asks topDownTable
  maybe notFound (pure . entryNumber) (M.lookup name table)
  where
    notFound =
      error $ "bindingNumber: variable " ++ pretty name ++ " not found."

-- | Run an action one nesting level deeper (increments 'topDownDepth').
deepen :: ForwardingM rep a -> ForwardingM rep a
deepen = local incrementDepth
  where
    incrementDepth env = env {topDownDepth = 1 + topDownDepth env}

-- | Check whether every name in @names@ was bound strictly before @point@,
-- comparing binding numbers.  Calls 'error' (via 'bindingNumber') if
-- @point@ or any of the names is not in scope.
areAvailableBefore :: Names -> VName -> ForwardingM rep Bool
areAvailableBefore names point = do
  pointN <- bindingNumber point
  all (< pointN) <$> traverse bindingNumber (namesToList names)

-- | Check whether a name was bound at the current nesting depth.  Calls
-- 'error' if the name is not in the symbol table.
isInCurrentBody :: VName -> ForwardingM rep Bool
isInCurrentBody name = do
  current <- asks topDownDepth
  table <- asks topDownTable
  maybe notFound (pure . (== current) . entryDepth) (M.lookup name table)
  where
    notFound =
      error $ "isInCurrentBody: variable " ++ pretty name ++ " not found."

-- | Check whether a name is marked as a target for the optimisation.
-- Calls 'error' if the name is not in the symbol table.
isOptimisable :: VName -> ForwardingM rep Bool
isOptimisable name = do
  table <- asks topDownTable
  maybe notFound (pure . entryOptimisable) (M.lookup name table)
  where
    notFound =
      error $ "isOptimisable: variable " ++ pretty name ++ " not found."

-- | Record that a name has been observed, together with every name it
-- aliases according to the symbol table.  A name that is not in scope
-- contributes only itself.
seenVar :: VName -> ForwardingM rep ()
seenVar name = do
  table <- asks topDownTable
  let aliases = foldMap entryAliases (M.lookup name table)
  tell mempty {bottomUpSeen = oneName name <> aliases}

-- | Run an action and also return the 'BottomUp' information it
-- produced.  The output is still propagated to the enclosing writer,
-- exactly as with 'listen'.
--
-- The original body rebound and re-returned the pair from 'listen',
-- which adds nothing; this is simply 'listen' at a specialised type.
tapBottomUp :: ForwardingM rep a -> ForwardingM rep (a, BottomUp rep)
tapBottomUp = listen

-- | Consider forwarding an in-place update into the binding of @v@.
--
-- The update writes @src@ at @slice@ (under certificates @cs@) and
-- binds the result to @dest_nm@ with decoration @dest_dec@.  If every
-- check below passes, a 'DesiredUpdate' for it is emitted through
-- 'forwardThese'; otherwise nothing is emitted.
--
-- Only conditions (2), (3) and (6) from the module documentation are
-- checked here, plus a requirement that @v@ is not of primitive type.
-- NOTE(review): the module FIXME says not all conditions are checked
-- yet; confirm whether the rest are enforced elsewhere.
maybeForward ::
  Constraints rep =>
  VName ->
  VName ->
  LetDec (Aliases rep) ->
  Certs ->
  VName ->
  Slice SubExp ->
  ForwardingM rep ()
maybeForward v dest_nm dest_dec cs src slice = do
  -- Condition (2): everything the update mentions must be available
  -- before @v@, as decided by 'areAvailableBefore'.
  let mentioned = freeIn src <> freeIn slice <> freeIn cs
  available <- areAvailableBefore mentioned v
  -- Condition (3): @v@ must be bound in the current body.
  samebody <- isInCurrentBody v
  -- Condition (6): @v@ must come from something we can optimise.
  optimisable <- isOptimisable v
  -- Values of primitive type are never forwarded.
  is_prim <- primType <$> lookupType v
  when (and [available, samebody, optimisable, not is_prim]) $
    tell
      mempty
        { forwardThese =
            [DesiredUpdate dest_nm dest_dec cs src slice v]
        }