{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
-- | This module implements an optimisation that moves in-place
-- updates into/before loops where possible, with the end goal of
-- minimising memory copies.  As an example, consider this program:
--
-- @
--   let r =
--     loop (r1 = r0) = for i < n do
--       let a = r1[i] in
--       let r1[i] = a * i in
--       r1
--       in
--   ...
--   let x = y with [k] <- r in
--   ...
-- @
--
-- We want to turn this into the following:
--
-- @
--   let x0 = y with [k] <- r0
--   loop (x = x0) = for i < n do
--     let a = x[k,i] in
--     let x[k,i] = a * i in
--     x
--     in
--   let r = x[k] in
--   ...
-- @
--
-- The intent is that we are also going to optimise the new data
-- movement (in the @x0@-binding), possibly by changing how @r0@ is
-- defined.  For the above transformation to be valid, a number of
-- conditions must be fulfilled:
--
--    (1) @r@ must not be consumed after the original in-place update.
--
--    (2) @k@ and @y@ must be available at the beginning of the loop.
--
--    (3) @x@ must be visible whenever @r@ is visible.  (This means
--    that both @x@ and @r@ must be bound in the same 'Body'.)
--
--    (4) If @x@ is consumed at a point after the loop, @r@ must not
--    be used after that point.
--
--    (5) The size of @r1@ is invariant inside the loop.
--
--    (6) The value @r@ must come from something that we can actually
--    optimise (e.g. not a function parameter).
--
--    (7) @y@ (or its aliases) may not be used inside the body of the
--    loop.
--
--    (8) The result of the loop may not alias the merge parameter
--    @r1@.
--
-- FIXME: the implementation is not finished yet.  Specifically, not
-- all of the above conditions are checked.
module Futhark.Optimise.InPlaceLowering
       (
         inPlaceLowering
       ) where

import Control.Monad.RWS
import qualified Data.Map.Strict as M

import Futhark.Analysis.Alias
import Futhark.Representation.Aliases
import Futhark.Representation.Kernels
import Futhark.Optimise.InPlaceLowering.LowerIntoStm
import Futhark.MonadFreshNames
import Futhark.Binder
import Futhark.Pass
import Futhark.Tools (fullSlice)

-- | Apply the in-place lowering optimisation to the given program.
--
-- The pass first annotates the program with alias information
-- ('aliasAnalysis'), then optimises the constants and every function
-- definition ('optimiseConsts' and 'optimiseFunDef', driven by
-- 'intraproceduralTransformationWithConsts'), and finally strips the
-- alias annotations again ('removeProgAliases').
inPlaceLowering :: Pass Kernels Kernels
inPlaceLowering =
  Pass "In-place lowering" "Lower in-place updates into loops" $
  fmap removeProgAliases .
  intraproceduralTransformationWithConsts optimiseConsts optimiseFunDef .
  aliasAnalysis

-- | Optimise the statements that define the program constants.
--
-- The statements are run through 'optimiseStms' with a no-op final
-- action, inside a 'ForwardingM' computation set up for the 'Kernels'
-- lore ('lowerUpdateKernels' and 'onKernelOp').  The name source of
-- the surrounding monad is threaded through via 'modifyNameSource'.
optimiseConsts :: MonadFreshNames m => Stms (Aliases Kernels)
               -> m (Stms (Aliases Kernels))
optimiseConsts stms =
  modifyNameSource $ runForwardingM lowerUpdateKernels onKernelOp $
  stmsFromList <$> optimiseStms (stmsToList stms) (pure ())

-- | Optimise the body of a single function definition.
--
-- The first argument is the statements defining the program
-- constants; each of them is brought into scope with 'bindingStm'
-- before the function parameters are bound and the body is optimised
-- with 'optimiseBody'.  The result is the same function definition
-- with only its body replaced.
optimiseFunDef :: MonadFreshNames m =>
                  Stms (Aliases Kernels) -> FunDef (Aliases Kernels)
               -> m (FunDef (Aliases Kernels))
optimiseFunDef consts fundec =
  modifyNameSource $ runForwardingM lowerUpdateKernels onKernelOp $
  descend (stmsToList consts) $
  bindingFParams (funDefParams fundec) $ do
    body <- optimiseBody $ funDefBody fundec
    return $ fundec { funDefBody = body }
  where -- Run the action with every given statement bound, the first
        -- statement being the outermost binding.
        descend [] m = m
        descend (stm:stms) m = bindingStm stm $ descend stms m

-- | The constraints a lore must satisfy for the lore-generic
-- optimisation functions in this module (e.g. 'optimiseBody' and
-- 'optimiseStms') to apply to it.
type Constraints lore = (Bindable lore, CanBeAliased (Op lore))

-- | Optimise the statements of a body.
--
-- The statements are processed by 'optimiseStms' inside 'deepen'
-- (presumably entering one more level of nesting - confirm against
-- its definition).  Once the last statement has been reached, every
-- variable occurring in the body result is registered with 'seenVar';
-- constants in the result are ignored.  The body attribute and the
-- result are returned unchanged.
optimiseBody :: Constraints lore =>
                Body (Aliases lore) -> ForwardingM lore (Body (Aliases lore))
optimiseBody (Body als bnds res) = do
  bnds' <- deepen $ optimiseStms (stmsToList bnds) $
    mapM_ seen res
  return $ Body als (stmsFromList bnds') res
  where seen Constant{} = return ()
        seen (Var v)    = seenVar v

-- | Optimise a sequence of statements.
--
-- The second argument is an action that is run once the end of the
-- statement list has been reached (callers use it to register the
-- variables used in the result of the enclosing body).
--
-- For each statement, the statements following it are processed
-- first, with the statement itself in scope, and the 'BottomUp'
-- information they produce is captured with 'tapBottomUp'.  The
-- desired updates whose 'updateValue' is bound by the current
-- statement are then handed to the 'lowerUpdate' function of the
-- environment:
--
--   * If lowering succeeds, the statements it produces are optimised
--     in turn and take the place of the current statement.
--
--   * If lowering fails, the updates are turned back into ordinary
--     statements ('updateStm') and placed after the current
--     statement.
--
-- In the cases where the current statement is kept, it is itself
-- checked for being an in-place update that can be forwarded.
optimiseStms :: Constraints lore =>
                [Stm (Aliases lore)] -> ForwardingM lore ()
             -> ForwardingM lore [Stm (Aliases lore)]
optimiseStms [] m = m >> return []

optimiseStms (bnd:bnds) m = do
  (bnds', bup) <- tapBottomUp $ bindingStm bnd $ optimiseStms bnds m
  bnd' <- optimiseInStm bnd
  case filter ((`elem` boundHere) . updateValue) $
       forwardThese bup of
    [] -> checkIfForwardableUpdate bnd' bnds'
    updates -> do
      let updateStms = map updateStm updates
      lower <- asks lowerUpdate
      scope <- askScope
      -- Condition (5) and (7) are assumed to be checked by
      -- lowerUpdate.
      case lower scope bnd' updates of
        Just lowering -> do
          new_bnds <- lowering
          -- Re-emit the captured bottom-up information, minus the
          -- desired updates, once the new statements are processed.
          new_bnds' <- optimiseStms new_bnds $
                       tell bup { forwardThese = [] }
          return $ new_bnds' ++ bnds'
        Nothing ->
          checkIfForwardableUpdate bnd' $
          updateStms ++ bnds'

  where -- The names bound by the pattern of the current statement.
        boundHere = patternNames $ stmPattern bnd

        -- A statement binding a single value to an 'Update' whose
        -- slice fixes the first dimension and is full in the rest is
        -- offered to 'maybeForward'.  If that reports success the
        -- statement is dropped; otherwise (and for every other kind
        -- of statement) it is put in front of the remaining ones.
        checkIfForwardableUpdate bnd'@(Let (Pattern [] [PatElem v attr])
                                       (StmAux cs _) e) bnds'
            | BasicOp (Update src (DimFix i:slice) (Var ve)) <- e,
              slice == drop 1 (fullSlice (typeOf attr) [DimFix i]) = do
                forwarded <- maybeForward ve v attr cs src i
                return $ if forwarded
                         then bnds'
                         else bnd' : bnds'
        checkIfForwardableUpdate bnd' bnds' =
          return $ bnd' : bnds'

-- | Optimise the expression of a statement with 'optimiseExp',
-- keeping its pattern and auxiliary information unchanged.
optimiseInStm :: Constraints lore => Stm (Aliases lore) -> ForwardingM lore (Stm (Aliases lore))
optimiseInStm (Let pat attr e) =
  Let pat attr <$> optimiseExp e

-- | Optimise the bodies nested inside an expression.
--
--   * For a loop, the names bound by the loop form and the merge
--     parameters (context and value) are brought into scope before
--     the loop body is optimised.
--
--   * A lore-specific operation is handed to the 'onOp' handler of
--     the environment.
--
--   * For any other expression, every body it contains is optimised
--     with 'optimiseBody'; everything else is left unchanged.
optimiseExp :: Constraints lore => Exp (Aliases lore) -> ForwardingM lore (Exp (Aliases lore))
optimiseExp (DoLoop ctx val form body) =
  bindingScope (scopeOf form) $
  bindingFParams (map fst $ ctx ++ val) $
  DoLoop ctx val form <$> optimiseBody body
optimiseExp (Op op) = do
  f <- asks onOp
  Op <$> f op
optimiseExp e = mapExpM optimise e
  where optimise = identityMapper { mapOnBody = const optimiseBody
                                  }
-- | How to optimise inside the operations of the 'Kernels' lore.
--
-- For a 'SegOp', the names bound by its 'SegSpace' are brought into
-- scope and the statements of its kernel body are optimised with
-- 'optimiseStms'; the free variables of the kernel body result are
-- registered with 'seenVar' once the last statement has been
-- reached.  Any other operation is returned unchanged.
onKernelOp :: OnOp Kernels
onKernelOp (SegOp op) =
  bindingScope (scopeOfSegSpace (segSpace op)) $ do
    let mapper = identitySegOpMapper { mapOnSegOpBody = onKernelBody }
        onKernelBody kbody = do
          stms <- deepen $ optimiseStms (stmsToList (kernelBodyStms kbody)) $
                  mapM_ seenVar $ namesToList $ freeIn $ kernelBodyResult kbody
          return kbody { kernelBodyStms = stmsFromList stms }
    SegOp <$> mapSegOpM mapper op
onKernelOp op = return op

-- | The information recorded about a variable in the 'VTable'.
--
-- NOTE(review): the code that fills in and reads these fields is not
-- visible in this part of the file.  Judging by names and types only,
-- 'entryNumber' and 'entryDepth' presumably correspond to
-- 'topDownCounter' and 'topDownDepth' at the time the variable was
-- bound - confirm against the binding functions.
data Entry lore = Entry { entryNumber :: Int
                        , entryAliases :: Names
                        , entryDepth :: Int
                        , entryOptimisable :: Bool
                        , entryType :: NameInfo (Aliases lore)
                        }

-- | Symbol table of the pass: one 'Entry' for every name in scope.
type VTable lore = M.Map VName (Entry lore)

-- | A handler for the lore-specific 'Op' of an aliased lore: it takes
-- an operation and returns an operation of the same type inside
-- 'ForwardingM'.
type OnOp lore = Op (Aliases lore) -> ForwardingM lore (Op (Aliases lore))

-- | The reader environment of 'ForwardingM'.
data TopDown lore = TopDown
  { topDownCounter :: Int
    -- ^ Bumped by one by each of 'bindingParams', 'bindingScope' and
    -- 'bindingStm'; copied into 'entryNumber' of the names they bind.
  , topDownTable :: VTable lore
    -- ^ Everything currently in scope.
  , topDownDepth :: Int
    -- ^ Nesting depth; incremented by 'deepen'.
  , lowerUpdate :: LowerUpdate lore (ForwardingM lore)
    -- ^ Lowering function passed to 'runForwardingM'.
  , onOp :: OnOp lore
    -- ^ 'Op' handler passed to 'runForwardingM'.
  }

-- | The writer output of 'ForwardingM'.
data BottomUp lore = BottomUp
  { bottomUpSeen :: Names
    -- ^ Names (and their aliases) reported through 'seenVar'.
  , forwardThese :: [DesiredUpdate (LetAttr (Aliases lore))]
    -- ^ Updates queued by 'maybeForward'.
  }

-- | Combination is field-wise: the seen names are merged and the
-- update lists are concatenated, left operand first.
instance Semigroup (BottomUp lore) where
  (<>) (BottomUp seenL fwdL) (BottomUp seenR fwdR) =
    BottomUp { bottomUpSeen = seenL <> seenR
             , forwardThese = fwdL <> fwdR
             }

-- | The neutral element has seen nothing and requests no updates.
instance Monoid (BottomUp lore) where
  mempty = BottomUp { bottomUpSeen = mempty
                    , forwardThese = mempty
                    }

-- | Turn a 'DesiredUpdate' back into the statement it describes: a
-- binding of 'updateName' to an 'Update' that writes 'updateValue'
-- into 'updateSource' at 'updateIndices', padded to a full slice of
-- the destination type.
updateStm :: Constraints lore => DesiredUpdate (LetAttr (Aliases lore)) -> Stm (Aliases lore)
updateStm fwd =
  mkLet [] [dest] (BasicOp (Update (updateSource fwd) slice newValue))
  where
    -- Type of the array being produced, taken from the update's attribute.
    destType = typeOf (updateType fwd)
    dest = Ident (updateName fwd) destType
    slice = fullSlice destType (updateIndices fwd)
    newValue = Var (updateValue fwd)

-- | The monad in which the pass runs: a reader over 'TopDown', a
-- writer of 'BottomUp', and state holding the 'VNameSource'.
newtype ForwardingM lore a =
  ForwardingM (RWS (TopDown lore) (BottomUp lore) VNameSource a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader (TopDown lore)
           , MonadWriter (BottomUp lore)
           , MonadState VNameSource
           )

-- | The name source is exactly the state component of the monad.
instance MonadFreshNames (ForwardingM lore) where
  getNameSource = get
  putNameSource newSource = put newSource

-- | The scope is the symbol table with every 'Entry' projected down
-- to its 'entryType'.
instance Constraints lore => HasScope (Aliases lore) (ForwardingM lore) where
  askScope = asks (M.map entryType . topDownTable)

-- | Run a 'ForwardingM' action from an empty symbol table, with the
-- counter and depth both at zero.  The two first arguments become the
-- 'lowerUpdate' and 'onOp' fields of the environment.  The
-- accumulated 'BottomUp' output is discarded; the result and the
-- final name source are returned.
runForwardingM :: LowerUpdate lore (ForwardingM lore) -> OnOp lore -> ForwardingM lore a
               -> VNameSource -> (a, VNameSource)
runForwardingM f g (ForwardingM m) src = (result, finalSrc)
  where
    -- A lazy pattern binding, so nothing is forced until the caller
    -- demands a component of the returned pair.
    (result, finalSrc, _) = runRWS m initialEnv src
    initialEnv =
      TopDown
        { topDownCounter = 0
        , topDownTable = M.empty
        , topDownDepth = 0
        , lowerUpdate = f
        , onOp = g
        }

-- | Run an action with the given parameters added to the symbol
-- table.  Each parameter gets the current counter and depth, no
-- aliases, and is marked as not optimisable; its scope information is
-- obtained by applying the first argument to the parameter attribute.
-- The counter is incremented for the duration of the action.
bindingParams :: (attr -> NameInfo (Aliases lore))
              -> [Param attr]
              -> ForwardingM lore a
              -> ForwardingM lore a
bindingParams f params = local extend
  where
    extend (TopDown n vtable d x y) =
      TopDown (n + 1) (M.union newEntries vtable) d x y
      where
        -- New entries take precedence over existing ones of the same name.
        newEntries =
          M.fromList
            [ (paramName p, Entry n mempty d False (f (paramAttr p)))
            | p <- params
            ]

-- | 'bindingParams' specialised to function parameters, whose scope
-- information is built with 'FParamInfo'.
bindingFParams :: [FParam (Aliases lore)]
               -> ForwardingM lore a
               -> ForwardingM lore a
bindingFParams fparams m = bindingParams FParamInfo fparams m

-- | Run an action with every name of the given scope added to the
-- symbol table.  The names get the current counter and depth and are
-- marked as not optimisable.  Alias information is kept for
-- let-bound names and is empty for everything else.  The counter is
-- incremented for the duration of the action.
bindingScope :: Scope (Aliases lore)
             -> ForwardingM lore a
             -> ForwardingM lore a
bindingScope scope = local extend
  where
    extend (TopDown n vtable d x y) =
      TopDown (n + 1) (M.union (M.map toEntry scope) vtable) d x y
      where
        toEntry info = Entry n (aliasesOf info) d False info

    -- Only let-bound names carry alias information in their attribute.
    aliasesOf info = case info of
      LetInfo (aliases, _) -> unNames aliases
      _ -> mempty

-- | Run an action with the names bound by a statement's pattern added
-- to the symbol table.  Each pattern element gets the current counter
-- and depth, the aliases stored in its attribute, and is marked as
-- optimisable.  The counter is incremented for the duration of the
-- action.
bindingStm :: Stm (Aliases lore)
           -> ForwardingM lore a
           -> ForwardingM lore a
bindingStm (Let pat _ _) = local extend
  where
    extend (TopDown n vtable d x y) =
      TopDown (n + 1) (M.union newEntries vtable) d x y
      where
        newEntries =
          M.fromList
            [ ( patElemName pe
              , Entry n (unNames (fst attr)) d True (LetInfo attr)
              )
            | pe <- patternElements pat
            , let attr = patElemAttr pe
            ]

-- | The 'entryNumber' of a name.  It is an internal error for the
-- name to be absent from the symbol table.
bindingNumber :: VName -> ForwardingM lore Int
bindingNumber name =
  asks (M.lookup name . topDownTable) >>= maybe notFound (return . entryNumber)
  where
    notFound =
      error $ "bindingNumber: variable " ++ pretty name ++ " not found."

-- | Run an action with 'topDownDepth' one higher than it currently is.
deepen :: ForwardingM lore a -> ForwardingM lore a
deepen = local oneDeeper
  where
    oneDeeper env = env { topDownDepth = topDownDepth env + 1 }

-- | Is every variable among the given subexpressions bound strictly
-- before @point@, judged by binding number?  Fails with an error if
-- @point@ or any of the variables is missing from the symbol table.
areAvailableBefore :: [SubExp] -> VName -> ForwardingM lore Bool
areAvailableBefore ses point = do
  limit <- bindingNumber point
  all (< limit) <$> mapM bindingNumber (subExpVars ses)

-- | Was the name entered into the symbol table at the current
-- 'topDownDepth'?  It is an internal error for the name to be absent
-- from the symbol table.
isInCurrentBody :: VName -> ForwardingM lore Bool
isInCurrentBody name = do
  current <- asks topDownDepth
  found <- asks (M.lookup name . topDownTable)
  maybe notFound (return . (== current) . entryDepth) found
  where
    notFound =
      error $ "isInCurrentBody: variable " ++ pretty name ++ " not found."

-- | The 'entryOptimisable' flag of a name.  It is an internal error
-- for the name to be absent from the symbol table.
isOptimisable :: VName -> ForwardingM lore Bool
isOptimisable name =
  asks (M.lookup name . topDownTable) >>= maybe notFound (return . entryOptimisable)
  where
    notFound =
      error $ "isOptimisable: variable " ++ pretty name ++ " not found."

-- | Report that a name has been seen.  The name is added to
-- 'bottomUpSeen' together with its recorded aliases; a name that is
-- not in the symbol table is treated as having no aliases.
seenVar :: VName -> ForwardingM lore ()
seenVar name = do
  table <- asks topDownTable
  let aliases = foldMap entryAliases (M.lookup name table)
  tell mempty { bottomUpSeen = oneName name <> aliases }

-- | Run an action and return its result paired with the 'BottomUp'
-- output it produced.  The output is still passed on to the
-- enclosing computation as usual.
--
-- This is exactly 'listen'; the previous body took the pair apart
-- only to rebuild it unchanged.
tapBottomUp :: ForwardingM lore a -> ForwardingM lore (a, BottomUp lore)
tapBottomUp = listen

-- | Decide whether the in-place update binding @dest_nm@ (updating
-- @src@ at index @i@ with the value @v@, under certificates @cs@) can
-- be forwarded to where @v@ is bound.  If so, the update is queued in
-- 'forwardThese' and 'True' is returned; otherwise nothing is queued
-- and the result is 'False'.  The condition numbers below refer to
-- the list in the module header.
maybeForward :: Constraints lore =>
                VName
             -> VName -> LetAttr (Aliases lore) -> Certificates -> VName -> SubExp
             -> ForwardingM lore Bool
maybeForward v dest_nm dest_attr cs src i = do
  -- Condition (2): the index and the source array exist before v.
  available <- areAvailableBefore [i, Var src] v
  -- ...and, as a subcondition, so do the certificates.
  certs_available <- areAvailableBefore (map Var (namesToList (freeIn cs))) v
  -- Condition (3).
  samebody <- isInCurrentBody v
  -- Condition (6).
  optimisable <- isOptimisable v
  -- The value being written must not be of primitive type.
  not_prim <- not . primType <$> lookupType v
  let forwardable =
        and [available, certs_available, samebody, optimisable, not_prim]
      fwd = DesiredUpdate dest_nm dest_attr cs src [DimFix i] v
  if forwardable
    then True <$ tell mempty { forwardThese = [fwd] }
    else return False