{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998

\section[FloatOut]{Float bindings outwards (towards the top level)}

``Long-distance'' floating of bindings towards the top level.
-}

{-# LANGUAGE CPP #-}

module FloatOut ( floatOutwards ) where

import GhcPrelude

import CoreSyn
import CoreUtils
import MkCore
import CoreArity        ( etaExpand )
import CoreMonad        ( FloatOutSwitches(..) )

import DynFlags
import ErrUtils         ( dumpIfSet_dyn )
import Id               ( Id, idArity, idType, isBottomingId,
                          isJoinId, isJoinId_maybe )
import SetLevels
import UniqSupply       ( UniqSupply )
import Bag
import Util
import Maybes
import Outputable
import Type
import qualified Data.IntMap as M

import Data.List        ( partition )

#include "HsVersions.h"

{-
        -----------------
        Overall game plan
        -----------------

The Big Main Idea is:

        To float out sub-expressions that can thereby get outside
        a non-one-shot value lambda, and hence may be shared.


To achieve this we may need to do two things:

   a) Let-bind the sub-expression:

        f (g x)  ==>  let lvl = f (g x) in lvl

      Now we can float the binding for 'lvl'.

   b) More than that, we may need to abstract wrt a type variable

        \x -> ... /\a -> let v = ...a... in ....

      Here the binding for v mentions 'a' but not 'x'.  So we
      abstract wrt 'a', to give this binding for 'v':

            vp = /\a -> ...a...
            v  = vp a

      Now the binding for vp can float out unimpeded.
      I can't remember why this case seemed important enough to
      deal with, but I certainly found cases where important floats
      didn't happen if we did not abstract wrt tyvars.

With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level.  Whether or not this
happens is controlled by a flag.


Random comments
~~~~~~~~~~~~~~~

At the moment we never float a binding out to between two adjacent
lambdas.  For example:

@
        \x y -> let t = x+x in ...
===>
        \x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.

But there's a case I've seen where this might not be true.  Consider:
@
elEm2 x ys
  = elem' x ys
  where
    elem' _ []  = False
    elem' x (y:ys)      = x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
        \deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated to
@
        \deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe.  We don't do this at the moment.

Note [Join points]
~~~~~~~~~~~~~~~~~~
Every occurrence of a join point must be a tail call (see Note [Invariants on
join points] in CoreSyn), so we must be careful with how far we float them. The
mechanism for doing so is the *join ceiling*, detailed in Note [Join ceiling]
in SetLevels. For us, the significance is that a binder might be marked to be
dropped at the nearest boundary between tail calls and non-tail calls. For
example:

  (< join j = ... in
     let x = < ... > in
     case < ... > of
       A -> ...
       B -> ...
   >) < ... > < ... >

Here the join ceilings are marked with angle brackets. Either side of an
application is a join ceiling, as is the scrutinee position of a case
expression or the RHS of a let binding (but not a join point).

Why do we *want* to float join points at all? After all, they're never
allocated, so there's no sharing to be gained by floating them. However, the
other benefit of floating is making RHSes small, and this can have a significant
impact. In particular, stream fusion has been known to produce nested loops like
this:

  joinrec j1 x1 =
    joinrec j2 x2 =
      joinrec j3 x3 = ... jump j1 (x3 + 1) ... jump j2 (x3 + 1) ...
      in jump j3 x2
    in jump j2 x1
  in jump j1 x

(Assume x1 and x2 do *not* occur free in j3.)

Here j1 and j2 are wholly superfluous---each of them merely forwards its
argument to j3. Since j3 only refers to x3, we can float j2 and j3 to make
everything one big mutual recursion:

  joinrec j1 x1 = jump j2 x1
          j2 x2 = jump j3 x2
          j3 x3 = ... jump j1 (x3 + 1) ... jump j2 (x3 + 1) ...
  in jump j1 x

Now the simplifier will happily inline the trivial j1 and j2, leaving only j3.
Without floating, we're stuck with three loops instead of one.

************************************************************************
*                                                                      *
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
*                                                                      *
************************************************************************
-}

floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
-- Entry point of the pass: annotate every top-level binding with levels
-- (setLevels), float each one with floatTopBind, optionally dump the
-- levelled program and the float statistics, and return the resulting
-- top-level bindings as a flat list.
floatOutwards float_sws dflags us pgm
  = do dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
         (vcat (map ppr levelled_binds))

       dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
         (hcat [ int n_top_lets,   text " Lets floated to top level; "
               , int n_other_lets, text " Lets floated elsewhere; from "
               , int n_lam_groups, text " Lambda groups" ])

       return (bagToList (unionManyBags per_bind_bags))
  where
    levelled_binds                = setLevels float_sws pgm us
    (per_bind_stats, per_bind_bags) = unzip (map floatTopBind levelled_binds)
    -- Totals across every top-level binding, reported by the stats dump
    (n_top_lets, n_other_lets, n_lam_groups) = get_stats (sum_stats per_bind_stats)

floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
-- Float a single top-level binding.  Whatever floats out of it is
-- flattened into a bag of top-level bindings (flattenTopFloats) and then:
--   * for a recursive group, merged into that group (addTopFloatPairs);
--   * for a non-recursive binding, placed in front of it.
-- floatBind is expected to hand back exactly one value binding here (no
-- unlifted bindings or join points at top level); anything else panics.
floatTopBind bind
  = case floatBind bind of
      (stats, floats, rebuilt) -> rebuild stats (flattenTopFloats floats) rebuilt
  where
    rebuild stats top_floats [Rec prs]
      = (stats, unitBag (Rec (addTopFloatPairs top_floats prs)))
    rebuild stats top_floats [NonRec b e]
      = (stats, top_floats `snocBag` NonRec b e)
    rebuild _ _ other
      = pprPanic "floatTopBind" (ppr other)

{-
************************************************************************
*                                                                      *
\subsection[FloatOut-Bind]{Floating in a binding (the business end)}
*                                                                      *
************************************************************************
-}

floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
  -- Float bindings out of the RHS(s) of one binding group.  Returns the
  -- float statistics, the bindings that keep floating outwards, and a
  -- list of rebuilt bindings that is either
  --   * A single non-recursive binding (value or join point), or
  --   * The following, in order:
  --     * Zero or more non-rec unlifted bindings
  --     * One or both of, in this order when both are present:
  --       * A recursive group of lifted value binds
  --       * A recursive group of join binds
  -- See Note [Floating out of Rec rhss] for why things get arranged this way.
floatBind (NonRec (TB var _) rhs)
  = case (floatRhs var rhs) of { (fs, rhs_floats, rhs') ->

        -- A tiresome hack:
        -- see Note [Bottoming floats: eta expansion] in SetLevels
    let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
              | otherwise         = rhs'

    in (fs, rhs_floats, [NonRec var rhs'']) }

floatBind (Rec pairs)
  = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
    let (new_ul_pairss, new_other_pairss) = unzip new_pairs
        (new_join_pairs, new_l_pairs)     = partition (isJoinId . fst)
                                                      (concat new_other_pairss)
        -- Can't put the join points and the values in the same rec group,
        -- so emit up to two groups: values first, then join points
        new_rec_binds | null new_join_pairs = [ Rec new_l_pairs    ]
                      | null new_l_pairs    = [ Rec new_join_pairs ]
                      | otherwise           = [ Rec new_l_pairs
                                              , Rec new_join_pairs ]
        -- Unlifted (non-join) floats cannot join a Rec; keep them non-rec
        new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
    in
    (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
  where
    do_pair :: (LevelledBndr, LevelledExpr)
            -> (FloatStats, FloatBinds,
                ([(Id,CoreExpr)],  -- Non-recursive unlifted value bindings
                 [(Id,CoreExpr)])) -- Join points and lifted value bindings
    do_pair (TB name spec, rhs)
      | isTopLvl dest_lvl  -- See Note [floatBind for top level]
      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
        (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
                                                [(name, rhs')]))}
      | otherwise         -- Note [Floating out of Rec rhss]
      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
        -- Floats destined for this binding's own level stay with the group
        case (partitionByLevel dest_lvl rhs_floats) of { (rhs_floats', heres) ->
        -- Leading lets join the group; a tail starting with a case float
        -- is pushed under the RHS's lambdas instead
        case (splitRecFloats heres) of { (ul_pairs, here_pairs, case_heres) ->
        let pairs' = (name, installUnderLambdas case_heres rhs') : here_pairs in
        (fs, rhs_floats', (ul_pairs, pairs')) }}}
      where
        dest_lvl = floatSpecLevel spec

splitRecFloats :: Bag FloatBind
               -> ([(Id,CoreExpr)], -- Non-recursive unlifted value bindings
                   [(Id,CoreExpr)], -- Join points and lifted value bindings
                   Bag FloatBind)   -- A tail of further bindings
-- Split the leading run of FloatLets off a bag of floats; the remaining
-- tail (if any) begins with something other than a FloatLet, i.e. a case.
-- Among the leading lets:
--   * non-recursive, unlifted, non-join bindings come back in their
--     original order in the first component;
--   * every other binding comes back in the second component, with the
--     let groups in reverse order (order inside a Rec group is kept).
-- See Note [Floating out of Rec rhss]
splitRecFloats fs
  = (unlifted_prs, other_prs, listToBag tail_floats)
  where
    (leading_lets, tail_floats) = span_lets (bagToList fs)

    unlifted_prs = [ (b, r) | NonRec b r <- leading_lets, is_unlifted_value b ]
    other_prs    = concat (reverse (map to_other leading_lets))

    to_other (NonRec b r) | is_unlifted_value b = []
                          | otherwise           = [(b, r)]
    to_other (Rec prs')                         = prs'

    is_unlifted_value b = isUnliftedType (idType b) && not (isJoinId b)

    -- Take FloatLets from the front until the first non-let float
    span_lets (FloatLet bind : rest) = let (binds, tl) = span_lets rest
                                       in (bind : binds, tl)
    span_lets rest                   = ([], rest)

installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
-- Wrap the given floats around the body of e, *inside* all of e's leading
-- lambdas, so that e's lambda structure (and hence arity) is untouched.
-- An empty bag leaves e as it is.
-- Note [Floating out of Rec rhss]
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = mkLams lam_bndrs (install floats lam_body)
  where
    (lam_bndrs, lam_body) = collectBinders e

---------------
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
-- Run a floating function over each element of a list.  Statistics are
-- combined with add_stats, floated bindings with plusFloats, and the
-- per-element results are collected in the original order.
floatList f = foldr step (zeroStats, emptyFloats, [])
  where
    -- Evaluate the current element before the rest of the list
    step x rest
      = case f x of
          (stats_x, floats_x, y) ->
            case rest of
              (stats_rest, floats_rest, ys) ->
                ( stats_x  `add_stats`  stats_rest
                , floats_x `plusFloats` floats_rest
                , y : ys )

{-
Note [Floating out of Rec rhss]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider   Rec { f<1,0> = \xy. body }
From the body we may get some floats. The ones with level <1,0> must
stay here, since they may mention f.  Ideally we'd like to make them
part of the Rec block pairs -- but we can't if there are any
FloatCases involved.

Nor is it a good idea to dump them in the rhs, but outside the lambda
    f = case x of I# y -> \xy. body
because now f's arity might get worse, which is Not Good. (And if
there's an SCC around the RHS it might not get better again.
See Trac #5342.)

So, gruesomely, we split the floats into
 * the outer FloatLets, which can join the Rec, and
 * an inner batch starting in a FloatCase, which are then
   pushed *inside* the lambdas.
This loses full laziness in the rare situation where there is a
FloatCase and a Rec interacting.

If there are unlifted FloatLets (that *aren't* join points) among the floats,
we can't add them to the recursive group without angering Core Lint, but since
they must be ok-for-speculation, they can't actually be making any recursive
calls, so we can safely pull them out and keep them non-recursive.

(Why is something getting floated to <1,0> that doesn't make a recursive call?
The case that came up in testing was that f *and* the unlifted binding were
getting floated *to the same place*:

  \x<2,0> ->
    ... <3,0>
    letrec { f<F<2,0>> =
      ... let x'<F<2,0>> = x +# 1# in ...
    } in ...

Everything gets labeled "float to <2,0>" because it all depends on x, but this
makes f and x' look mutually recursive when they're not.

The test was shootout/k-nucleotide, as compiled using commit 47d5dd68 on the
wip/join-points branch.

TODO: This can probably be solved somehow in SetLevels. The difference between
"this *is at* level <2,0>" and "this *depends on* level <2,0>" is very
important.)

Note [floatBind for top level]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We may have a *nested* binding whose destination level is (FloatMe tOP_LEVEL), thus
         letrec { foo <0,0> = .... (let bar<0,0> = .. in ..) .... }
The binding for bar will be in the "tops" part of the floating binds,
and thus not partitioned by floatBody.

We could perhaps get rid of the 'tops' component of the floating binds,
but this case works just as well.


************************************************************************

\subsection[FloatOut-Expr]{Floating in expressions}
*                                                                      *
************************************************************************
-}

floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
-- Float out of an expression that sits directly below a binding site at
-- level `lvl` (e.g. the body of a lambda group in floatExpr).  Floats whose
-- destination is `lvl` are installed around the floated expression here;
-- all other floats keep moving outwards.
floatBody lvl arg
  = case floatExpr arg of
      (body_stats, body_floats, arg') ->
        case partitionByLevel lvl body_floats of
          (outer_floats, here_floats) ->
            (body_stats, outer_floats, install here_floats arg')

-----------------

{- Note [Floating past breakpoints]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We used to disallow floating out of breakpoint ticks (see #10052). However, I
think this is too restrictive.

Consider the case of an expression scoped over by a breakpoint tick,

  tick<...> (let x = ... in f x)

In this case it is completely legal to float out x, despite the fact that
breakpoint ticks are scoped,

  let x = ... in (tick<...>  f x)

The reason here is that we know that the breakpoint will still be hit when the
expression is entered since the tick still scopes over the RHS.

-}

floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatExpr :: Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatExpr (Var v :: CoreBndr
v)   = (FloatStats
zeroStats, FloatBinds
emptyFloats, CoreBndr -> Expr CoreBndr
forall b. CoreBndr -> Expr b
Var CoreBndr
v)
floatExpr (Type ty :: Type
ty) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Type -> Expr CoreBndr
forall b. Type -> Expr b
Type Type
ty)
floatExpr (Coercion co :: Coercion
co) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Coercion -> Expr CoreBndr
forall b. Coercion -> Expr b
Coercion Coercion
co)
floatExpr (Lit lit :: Literal
lit) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Literal -> Expr CoreBndr
forall b. Literal -> Expr b
Lit Literal
lit)

floatExpr (App e :: Expr (TaggedBndr FloatSpec)
e a :: Expr (TaggedBndr FloatSpec)
a)
  = case ((FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
atJoinCeiling ((FloatStats, FloatBinds, Expr CoreBndr)
 -> (FloatStats, FloatBinds, Expr CoreBndr))
-> (FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
forall a b. (a -> b) -> a -> b
$ Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatExpr  Expr (TaggedBndr FloatSpec)
e) of { (fse :: FloatStats
fse, floats_e :: FloatBinds
floats_e, e' :: Expr CoreBndr
e') ->
    case ((FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
atJoinCeiling ((FloatStats, FloatBinds, Expr CoreBndr)
 -> (FloatStats, FloatBinds, Expr CoreBndr))
-> (FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
forall a b. (a -> b) -> a -> b
$ Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatExpr  Expr (TaggedBndr FloatSpec)
a) of { (fsa :: FloatStats
fsa, floats_a :: FloatBinds
floats_a, a' :: Expr CoreBndr
a') ->
    (FloatStats
fse FloatStats -> FloatStats -> FloatStats
`add_stats` FloatStats
fsa, FloatBinds
floats_e FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
floats_a, Expr CoreBndr -> Expr CoreBndr -> Expr CoreBndr
forall b. Expr b -> Expr b -> Expr b
App Expr CoreBndr
e' Expr CoreBndr
a') }}

floatExpr lam :: Expr (TaggedBndr FloatSpec)
lam@(Lam (TB _ lam_spec :: FloatSpec
lam_spec) _)
  = let (bndrs_w_lvls :: [TaggedBndr FloatSpec]
bndrs_w_lvls, body :: Expr (TaggedBndr FloatSpec)
body) = Expr (TaggedBndr FloatSpec)
-> ([TaggedBndr FloatSpec], Expr (TaggedBndr FloatSpec))
forall b. Expr b -> ([b], Expr b)
collectBinders Expr (TaggedBndr FloatSpec)
lam
        bndrs :: [CoreBndr]
bndrs                = [CoreBndr
b | TB b :: CoreBndr
b _ <- [TaggedBndr FloatSpec]
bndrs_w_lvls]
        bndr_lvl :: Level
bndr_lvl             = Level -> Level
asJoinCeilLvl (FloatSpec -> Level
floatSpecLevel FloatSpec
lam_spec)
        -- All the binders have the same level
        -- See SetLevels.lvlLamBndrs
        -- Use asJoinCeilLvl to make this the join ceiling
    in
    case (Level
-> Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatBody Level
bndr_lvl Expr (TaggedBndr FloatSpec)
body) of { (fs :: FloatStats
fs, floats :: FloatBinds
floats, body' :: Expr CoreBndr
body') ->
    (FloatStats -> FloatBinds -> FloatStats
add_to_stats FloatStats
fs FloatBinds
floats, FloatBinds
floats, [CoreBndr] -> Expr CoreBndr -> Expr CoreBndr
forall b. [b] -> Expr b -> Expr b
mkLams [CoreBndr]
bndrs Expr CoreBndr
body') }

floatExpr (Tick tickish :: Tickish CoreBndr
tickish expr :: Expr (TaggedBndr FloatSpec)
expr)
  | Tickish CoreBndr
tickish Tickish CoreBndr -> TickishScoping -> Bool
forall id. Tickish id -> TickishScoping -> Bool
`tickishScopesLike` TickishScoping
SoftScope -- not scoped, can just float
  = case ((FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
atJoinCeiling ((FloatStats, FloatBinds, Expr CoreBndr)
 -> (FloatStats, FloatBinds, Expr CoreBndr))
-> (FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
forall a b. (a -> b) -> a -> b
$ Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (fs :: FloatStats
fs, floating_defns :: FloatBinds
floating_defns, expr' :: Expr CoreBndr
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, Tickish CoreBndr -> Expr CoreBndr -> Expr CoreBndr
forall b. Tickish CoreBndr -> Expr b -> Expr b
Tick Tickish CoreBndr
tickish Expr CoreBndr
expr') }

  | Bool -> Bool
not (Tickish CoreBndr -> Bool
forall id. Tickish id -> Bool
tickishCounts Tickish CoreBndr
tickish) Bool -> Bool -> Bool
|| Tickish CoreBndr -> Bool
forall id. Tickish id -> Bool
tickishCanSplit Tickish CoreBndr
tickish
  = case ((FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
atJoinCeiling ((FloatStats, FloatBinds, Expr CoreBndr)
 -> (FloatStats, FloatBinds, Expr CoreBndr))
-> (FloatStats, FloatBinds, Expr CoreBndr)
-> (FloatStats, FloatBinds, Expr CoreBndr)
forall a b. (a -> b) -> a -> b
$ Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (fs :: FloatStats
fs, floating_defns :: FloatBinds
floating_defns, expr' :: Expr CoreBndr
expr') ->
    let -- Annotate bindings floated outwards past an scc expression
        -- with the cc.  We mark that cc as "duplicated", though.
        annotated_defns :: FloatBinds
annotated_defns = Tickish CoreBndr -> FloatBinds -> FloatBinds
wrapTick (Tickish CoreBndr -> Tickish CoreBndr
forall id. Tickish id -> Tickish id
mkNoCount Tickish CoreBndr
tickish) FloatBinds
floating_defns
    in
    (FloatStats
fs, FloatBinds
annotated_defns, Tickish CoreBndr -> Expr CoreBndr -> Expr CoreBndr
forall b. Tickish CoreBndr -> Expr b -> Expr b
Tick Tickish CoreBndr
tickish Expr CoreBndr
expr') }

  -- Note [Floating past breakpoints]
  | Breakpoint{} <- Tickish CoreBndr
tickish
  = case (Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr CoreBndr)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (fs :: FloatStats
fs, floating_defns :: FloatBinds
floating_defns, expr' :: Expr CoreBndr
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, Tickish CoreBndr -> Expr CoreBndr -> Expr CoreBndr
forall b. Tickish CoreBndr -> Expr b -> Expr b
Tick Tickish CoreBndr
tickish Expr CoreBndr
expr') }

  | Bool
otherwise
  = String -> SDoc -> (FloatStats, FloatBinds, Expr CoreBndr)
forall a. HasCallStack => String -> SDoc -> a
pprPanic "floatExpr tick" (Tickish CoreBndr -> SDoc
forall a. Outputable a => a -> SDoc
ppr Tickish CoreBndr
tickish)

-- A cast: float out of the body, installing any join-ceiling floats
-- around the new body, then re-apply the same coercion.
floatExpr (Cast expr co)
  = case atJoinCeiling (floatExpr expr) of
      (fs, floating_defns, expr') -> (fs, floating_defns, Cast expr' co)

-- A let: the FloatSpec attached to the (first) binder decides whether the
-- binding moves to 'dest_lvl' or stays where it is.
floatExpr (Let bind body)
  = case bind_spec of
      FloatMe dest_lvl ->
        case floatBind bind of
          (fsb, bind_floats, binds') ->
            case floatExpr body of
              (fse, body_floats, body') ->
                -- Every resulting binding becomes a float at 'dest_lvl';
                -- the let itself disappears from this position.
                let moved = foldr (plusFloats . unitLetFloat dest_lvl)
                                  emptyFloats binds'
                in ( add_stats fsb fse
                   , bind_floats `plusFloats` moved `plusFloats` body_floats
                   , body' )

      StayPut bind_lvl ->  -- See Note [Avoiding unnecessary floating]
        case floatBind bind of
          (fsb, bind_floats, binds') ->
            case floatBody bind_lvl body of
              (fse, body_floats, body') ->
                ( add_stats fsb fse
                , bind_floats `plusFloats` body_floats
                , foldr Let body' binds' )
  where
    -- All binders of a group share one FloatSpec, so read it off the first.
    bind_spec = case bind of
      NonRec (TB _ spec) _     -> spec
      Rec ((TB _ spec, _) : _) -> spec
      Rec []                   -> panic "floatExpr:rec"

-- A case: either the whole single-alternative case moves out as a
-- case-float (FloatMe), or it stays put and we float out of the scrutinee
-- and each alternative (StayPut).
floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      FloatMe dest_lvl  -- Case expression moves
        | [(con@(DataAlt {}), bndrs, rhs)] <- alts
        -> case atJoinCeiling (floatExpr scrut) of
             (fse, fde, scrut') ->
               case floatExpr rhs of
                 (fsb, fdb, rhs') ->
                   let case_float = unitCaseFloat dest_lvl scrut' case_bndr
                                                  con (strip_tags bndrs)
                   in ( add_stats fse fsb
                      , fde `plusFloats` case_float `plusFloats` fdb
                      , rhs' )
        | otherwise
        -> pprPanic "Floating multi-case" (ppr alts)

      StayPut bind_lvl  -- Case expression stays put
        -> case atJoinCeiling (floatExpr scrut) of
             (fse, fde, scrut') ->
               case floatList (float_alt bind_lvl) alts of
                 (fsa, fda, alts') ->
                   ( add_stats fse fsa
                   , fda `plusFloats` fde
                   , Case scrut' case_bndr ty alts' )
  where
    -- Drop the FloatSpec tags from a list of tagged binders.
    strip_tags tagged = [b | TB b _ <- tagged]

    -- Float out of one alternative's RHS at the case's own level.
    float_alt bind_lvl (con, bs, rhs)
      = case floatBody bind_lvl rhs of
          (fs, rhs_floats, rhs') ->
            (fs, rhs_floats, (con, strip_tags bs, rhs'))

-- | Float out of the right-hand side of a binding.
--
-- For a join point whose RHS has at least join-arity leading lambdas, the
-- lambdas are kept in place and only the body beneath them is floated.  The
-- body floats at the level of the first lambda's FloatSpec, or as a plain
-- expression when the arity is 0.  Any other RHS is floated as an ordinary
-- expression, with join-ceiling floats installed around it.
floatRhs :: CoreBndr
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
floatRhs bndr rhs
  | Just join_arity <- isJoinId_maybe bndr
  , Just (bndrs, body) <- peel_lams join_arity rhs []
  = case bndrs of
      [] -> floatExpr rhs
      TB _ lam_spec : _ ->
        case floatBody (floatSpecLevel lam_spec) body of
          (fs, floats, body') ->
            (fs, floats, mkLams [b | TB b _ <- bndrs] body')
  | otherwise
  = atJoinCeiling (floatExpr rhs)
  where
    -- Strip exactly n leading lambdas.  The result holds the binders
    -- (outermost first) and the remaining body, or Nothing if there are
    -- fewer than n lambdas.
    peel_lams :: Int -> Expr b -> [b] -> Maybe ([b], Expr b)
    peel_lams 0 e         acc = Just (reverse acc, e)
    peel_lams n (Lam b e) acc = peel_lams (n - 1) e (b : acc)
    peel_lams _ _         _   = Nothing

{-
Note [Avoiding unnecessary floating]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In general we want to avoid floating a let unnecessarily, because
it might worsen strictness:
    let
       x = ...(let y = e in y+y)....
Here y is demanded.  If we float it outside the lazy 'x=..' then
we'd have to zap its demand info, and it may never be restored.

So at a 'let' we leave the binding right where they are unless
the binding will escape a value lambda, e.g.

(\x -> let y = fac 100 in y)

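That's what the StayPut branch of the floatExpr (Let ...) case does: it
floats the let body with floatBody at the binding's own level.  (This
used to be done by partitionByMajorLevel.)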

Notice, though, that we must take care to drop any bindings
from the body of the let that depend on the staying-put bindings.

We used instead to do the partitionByMajorLevel on the RHS of an '=',
in floatRhs.  But that was quite tiresome.  We needed to test for
values or trivial rhss, because (in particular) we don't want to insert
new bindings between the "=" and the "\".  E.g.
        f = \x -> let <bind> in <body>
We do not want
        f = let <bind> in \x -> <body>
(a) The simplifier will immediately float it further out, so we may
        as well do so right now; in general, keeping rhss as manifest
        values is good
(b) If a float-in pass follows immediately, it might add yet more
        bindings just after the '='.  And some of them might (correctly)
        be strict even though the 'let f' is lazy, because f, being a value,
        gets its demand-info zapped by the simplifier.
And even all that turned out to be very fragile, and broke
altogether when profiling got in the way.

So now we do the partition right at the (Let..) itself.

************************************************************************
*                                                                      *
\subsection{Utility bits for floating stats}
*                                                                      *
************************************************************************

I didn't implement this with unboxed numbers.  I don't want to be too
strict in this stuff, as it is rarely turned on.  (WDP 95/09)
-}

-- | Counters gathered while floating.  The fields are deliberately lazy;
-- these stats are rarely turned on.
data FloatStats
  = FlS Int  -- ^ Number of top-floats * lambda groups they've been past
        Int  -- ^ Number of non-top-floats * lambda groups they've been past
        Int  -- ^ Number of lambda (groups) seen

-- | Unpack the three counters: (top floats, other floats, lambda groups).
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS tops others lams) = (tops, others, lams)

-- | All counters at zero; the identity for 'add_stats'.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- | Add up a list of stats, counter by counter.
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats

-- | Pointwise sum of two sets of counters.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS tops1 others1 lams1) (FlS tops2 others2 lams2)
  = FlS (tops1 + tops2) (others1 + others2) (lams1 + lams2)

-- | Add the counts of a FloatBinds to the stats.  Top-level floats go to
-- the first counter; join-ceiling floats and other-level floats go to the
-- second.  The lambda-group counter is bumped by one.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS top_count other_count lam_count) (FB tops ceils others)
  = FlS top_count' other_count' lam_count'
  where
    top_count'   = top_count + lengthBag tops
    other_count' = other_count + lengthBag ceils
                               + lengthBag (flattenMajor others)
    lam_count'   = lam_count + 1

{-
************************************************************************
*                                                                      *
\subsection{Utility bits for floating}
*                                                                      *
************************************************************************

Note [Representation of FloatBinds]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The FloatBinds type is somewhat important.  We can get very large numbers
of floating bindings, often all destined for the top level.  A typical example
is     x = [4,2,5,2,5, .... ]
Then we get lots of small expressions like (fromInteger 4), which all get
lifted to top level.

The trouble is that
  (a) we partition these floating bindings *at every binding site*
  (b) SetLevels introduces a new binding site for every float
So we had better not look at each binding at each binding site!

That is why MajorEnv is represented as a finite map.

We keep the bindings destined for the *top* level separate, because
we float them out even if they don't escape a *value* lambda; see
partitionByLevel, which always passes the top-level bindings outwards.
-}

-- | A floated let-binding.
-- INVARIANT: a FloatLet is always lifted
type FloatLet = CoreBind

-- | Non-top, non-ceiling floats, keyed by major level.
type MajorEnv = M.IntMap MinorEnv

-- | Floats within one major level, keyed by minor level.
type MinorEnv = M.IntMap (Bag FloatBind)

-- | The floats collected so far, split by destination.
-- See Note [Representation of FloatBinds]
data FloatBinds
  = FB !(Bag FloatLet)   -- ^ Destined for top level
       !(Bag FloatBind)  -- ^ Destined for join ceiling
       !MajorEnv         -- ^ Other levels

-- Debug rendering: one line each for the top-level, join-ceiling and
-- other-level floats.
instance Outputable FloatBinds where
  ppr (FB top_binds ceil_binds level_env)
    = text "FB" <+> braces (vcat rows)
    where
      rows = [ text "tops ="     <+> ppr top_binds
             , text "ceils ="    <+> ppr ceil_binds
             , text "non-tops =" <+> ppr level_env ]

-- | Extract the top-level floats.  In debug builds, assert that no
-- join-ceiling or other-level floats are left over.
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB top_binds ceil_binds level_env)
  = ASSERT2( isEmptyBag (flattenMajor level_env), ppr level_env )
    ASSERT2( isEmptyBag ceil_binds, ppr ceil_binds )
    top_binds

-- | Prepend the (binder, rhs) pairs of a bag of top-level floats onto an
-- existing list of pairs.  Bag order is kept, and the pairs of a Rec group
-- stay adjacent and in order.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs
  = foldrBag prepend_bind prs float_bag
  where
    prepend_bind bind rest = case bind of
      NonRec b r -> (b, r) : rest
      Rec pairs  -> pairs ++ rest

-- | Collect every float in a MajorEnv, in ascending key order.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor = unionManyBags . map flattenMinor . M.elems

-- | Collect every float in a MinorEnv, in ascending key order.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = unionManyBags . M.elems

-- | No floats at any destination; the identity for 'plusFloats'.
emptyFloats :: FloatBinds
emptyFloats = FB no_tops no_ceils M.empty
  where
    no_tops  = emptyBag
    no_ceils = emptyBag

-- | A single case-float destined for the given level.  It goes into the
-- join-ceiling bag if the level is a join ceiling, otherwise into the
-- (major, minor) slot of the level env.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor t) e b con bs
  | t == JoinCeilLvl = FB emptyBag one_float M.empty
  | otherwise        = FB emptyBag emptyBag level_env
  where
    one_float = unitBag (FloatCase e b con bs)
    level_env = M.singleton major (M.singleton minor one_float)

-- | A single let-float destined for the given level.  A top level puts it
-- in the top bag; a join ceiling puts it in the ceiling bag; otherwise it
-- goes into the (major, minor) slot of the level env.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor t) b
  | isTopLvl lvl     = FB (unitBag b) emptyBag M.empty
  | t == JoinCeilLvl = FB emptyBag one_float M.empty
  | otherwise        = FB emptyBag emptyBag level_env
  where
    one_float = unitBag (FloatLet b)
    level_env = M.singleton major (M.singleton minor one_float)

-- | Combine two FloatBinds component-wise.  Within each component the
-- left argument's floats come first.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB tops_l ceils_l env_l) (FB tops_r ceils_r env_r)
  = FB tops ceils env
  where
    tops  = unionBags tops_l tops_r
    ceils = unionBags ceils_l ceils_r
    env   = plusMajor env_l env_r

-- | Union two MajorEnvs; entries at the same major level are merged with
-- 'plusMinor'.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor env1 env2 = M.unionWith plusMinor env1 env2

-- | Union two MinorEnvs; bags at the same minor level are concatenated,
-- left first.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor env1 env2 = M.unionWith unionBags env1 env2

-- | Wrap an expression in a bag of floats.  The first float in the bag
-- ends up outermost.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install floats body = foldrBag wrapFloat body floats

-- | Split the floats at a binding site of the given level.
--
-- The first result keeps floating outwards.  It holds all top-level
-- floats, the join-ceiling floats (unless this level is itself a join
-- ceiling), and every level-env float whose (major, minor) level is
-- strictly smaller than this one.
--
-- The second result is dropped here.  It holds the floats at exactly this
-- (major, minor) level, the join-ceiling floats if this level is a join
-- ceiling, and every float at a strictly greater level.
partitionByLevel
        :: Level                -- Partitioning level
        -> FloatBinds           -- Defns to be divided into 2 piles...
        -> (FloatBinds,         -- Defns with level strictly < partition level,
            Bag FloatBind)      -- The rest

{-
--       ---- partitionByMajorLevel ----
-- Float it if we escape a value lambda,
--     *or* if we get to the top level
--     *or* if it's a case-float and its minor level is < current
--
-- If we can get to the top level, say "yes" anyway. This means that
--      x = f e
-- transforms to
--    lvl = e
--    x = f lvl
-- which is as it should be

partitionByMajorLevel (Level major _) (FB tops defns)
  = (FB tops outer, heres `unionBags` flattenMajor inner)
  where
    (outer, mb_heres, inner) = M.splitLookup major defns
    heres = case mb_heres of
               Nothing -> emptyBag
               Just h  -> flattenMinor h
-}

partitionByLevel (Level major minor typ) (FB tops ceils defns)
  = (FB tops ceils' outer,
     here_min `unionBags` here_ceil
              `unionBags` flattenMinor inner_min
              `unionBags` flattenMajor inner_maj)
  where
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
    (outer_min, mb_here_min, inner_min)
      = case mb_here_maj of
          Nothing        -> (M.empty, Nothing, M.empty)
          Just min_defns -> M.splitLookup minor min_defns
    here_min = mb_here_min `orElse` emptyBag

    -- Floats at this major level with a smaller minor level keep floating.
    -- 'outer_maj' only has keys < major, so an insert is the same as
    -- 'plusMajor' with a singleton.  Skip it when there is nothing to keep,
    -- so empty MinorEnv entries do not pile up in the outer env.
    outer | M.null outer_min = outer_maj
          | otherwise        = M.insert major outer_min outer_maj

    -- At a join ceiling, the pending ceiling floats are dropped here.
    (here_ceil, ceils') | typ == JoinCeilLvl = (ceils, emptyBag)
                        | otherwise          = (emptyBag, ceils)

-- | Like partitionByLevel, but split out only the bindings marked to float
-- to the nearest join ceiling (see Note [Join points]).  Everything else is
-- left in the first result unchanged.
partitionAtJoinCeiling :: FloatBinds -> (FloatBinds, Bag FloatBind)
partitionAtJoinCeiling (FB tops ceils defs) = (without_ceils, ceils)
  where
    without_ceils = FB tops emptyBag defs

-- | Treat this point as a join ceiling: install the pending join-ceiling
-- floats around the expression so they go no further, and pass the other
-- floats on unchanged (see Note [Join points]).
atJoinCeiling :: (FloatStats, FloatBinds, CoreExpr)
              -> (FloatStats, FloatBinds, CoreExpr)
atJoinCeiling (fs, floats, expr')
  = case partitionAtJoinCeiling floats of
      ~(outer_floats, ceil_floats) ->
        (fs, outer_floats, install ceil_floats expr')

-- | Wrap the tick around the RHS of every float (let and case alike), at
-- every destination.  Used when floats move outwards past a tick.
wrapTick :: Tickish Id -> FloatBinds -> FloatBinds
wrapTick t (FB tops ceils defns)
  = FB tops' ceils' defns'
  where
    tops'  = mapBag tick_bind tops
    ceils' = tick_floats ceils
    defns' = M.map (M.map tick_floats) defns

    tick_floats = mapBag tick_float

    tick_float (FloatLet bind)      = FloatLet (tick_bind bind)
    tick_float (FloatCase e b c bs) = FloatCase (maybe_tick e) b c bs

    tick_bind (NonRec binder rhs) = NonRec binder (maybe_tick rhs)
    tick_bind (Rec pairs)         = Rec (mapSnd maybe_tick pairs)

    -- We don't need to wrap a tick around an HNF when we float it outside
    -- a tick: that is an invariant of the tick semantics.  Conversely,
    -- inlining of HNFs inside an SCC is allowed, and the HNF we're floating
    -- here might well be inlined back again, so we don't want to end up
    -- with duplicate ticks.
    maybe_tick e
      | exprIsHNF e = tickHNFArgs t e
      | otherwise   = mkTick t e