{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM'
-- monad.
module GHC.Stg.Lift.Monad (
    decomposeStgBinding, mkStgBinding,
    Env (..),
    -- * #floats# Handling floats
    -- $floats
    FloatLang (..), collectFloats, -- Exported just for the docs
    -- * Transformation monad
    LiftM, runLiftM,
    -- ** Adding bindings
    startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
    -- ** Substitution and binders
    withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
    -- ** Occurrences
    substOcc, isLifted, formerFreeVars, liftedIdsExpander
  ) where

#include "HsVersions.h"

import GHC.Prelude

import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Driver.Session
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Utils.Panic
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity

import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )

-- | Splits a binding into its recursivity flag and its binder/RHS pairs.
-- A non-recursive binding always yields exactly one pair.
--
-- @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding binding = case binding of
  StgRec pairs       -> (Recursive, pairs)
  StgNonRec bndr rhs -> (NonRecursive, [(bndr, rhs)])

-- | Rebuilds a binding from a recursivity flag and binder/RHS pairs; the
-- inverse of 'decomposeStgBinding'.
--
-- A 'NonRecursive' binding must be given exactly one pair. The previous
-- implementation used 'head' here, which silently dropped surplus pairs and
-- crashed with an uninformative error on an empty list. Violations of the
-- contract now panic with a descriptive message instead.
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive    pairs         = StgRec pairs
mkStgBinding NonRecursive [(bndr, rhs)] = StgNonRec bndr rhs
mkStgBinding NonRecursive pairs
  = pprPanic "mkStgBinding"
      (text "non-recursive binding needs exactly one pair, got" <+> int (length pairs))

-- | Environment threaded around in a scoped, @Reader@-like fashion.
data Env
  = Env
  { -- | Read-only.
    e_dflags     :: !DynFlags
    -- | We need to track the renamings of local 'InId's to their lifted
    -- 'OutId', because shadowing might make a closure's free variables
    -- unavailable at its call sites. Consider:
    -- @
    --    let f y = x + y in let x = 4 in f x
    -- @
    -- Here, @f@ can't be lifted to top-level, because its free variable @x@
    -- isn't available at its call site.
  , e_subst      :: !Subst
    -- | Lifted 'Id's don't occur as free variables in any closure anymore,
    -- because they are bound at the top-level. Every occurrence must supply
    -- the formerly free variables of the lifted 'Id', so they in turn become
    -- free variables of the call sites. This environment tracks this expansion
    -- from lifted 'Id's to their free variables.
    --
    -- Maps 'InId's to sets of 'OutId's.
    --
    -- Invariant: 'Id's not present in this map are not expanded.
  , e_expansions :: !(IdEnv DIdSet)
  }

-- | The initial environment: the given 'DynFlags', an empty substitution and
-- no lifted-binder expansions.
emptyEnv :: DynFlags -> Env
emptyEnv dflags = Env
  { e_dflags     = dflags
  , e_subst      = emptySubst
  , e_expansions = emptyVarEnv
  }


-- Note [Handling floats]
-- ~~~~~~~~~~~~~~~~~~~~~~
-- $floats
-- Consider the following expression:
--
-- @
--     f x =
--       let g y = ... f y ...
--       in g x
-- @
--
-- What happens when we want to lift @g@? Normally, we'd put the lifted @g@
-- binding above the binding for @f@, abstracting over its free variable @f@:
--
-- @
--     g f y = ... f y ...
--     f x = g f x
-- @
--
-- But this needlessly turns a known call to @f@ into an unknown one, in
-- addition to complicating matters for the analysis.
-- Instead, we'd really like to put both functions in the same recursive group,
-- thereby preserving the known call:
--
-- @
--     Rec {
--       g y = ... f y ...
--       f x = g x
--     }
-- @
--
-- But we don't want this to happen for just /any/ binding. That would create
-- possibly huge recursive groups in the process, calling for an occurrence
-- analyser on STG.
-- So, we need to track when we lift a binding out of a recursive RHS and add
-- the binding to the same recursive group as the enclosing recursive binding
-- (which must have either already been at the top-level or decided to be
-- lifted itself in order to preserve the known call).
--
-- This is done by expressing this kind of nesting structure as a 'Writer' over
-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to
-- 'collectFloats'.
-- API-wise, the analysis will not need to know about the whole 'FloatLang'
-- business and will just manipulate it indirectly through actions in 'LiftM'.

-- | We need to detect when we are lifting something out of the RHS of a
-- recursive binding (c.f. "GHC.Stg.Lift.Monad#floats"), in which case that
-- binding needs to be added to the same top-level recursive group. This
-- requires we detect a certain nesting structure, which is encoded by
-- 'StartBindingGroup' and 'EndBindingGroup'.
--
-- Note that 'collectFloats' only ever cares whether the current binding to be
-- lifted (through 'LiftedBinding') occurs inside such a binding group or not;
-- i.e. it doesn't care about the nesting level as long as it's greater than 0.
data FloatLang
  = StartBindingGroup
  -- ^ Opens a (possibly nested) binding group.
  | EndBindingGroup
  -- ^ Closes the innermost open binding group.
  | PlainTopBinding OutStgTopBinding
  -- ^ A top-level binding emitted as-is; 'collectFloats' only accepts these
  -- outside of any binding group.
  | LiftedBinding OutStgBinding
  -- ^ A lifted binding; inside a binding group it is merged into that
  -- group's recursive top-level binding.

-- | Compact rendering of a float: @(@ and @)@ mark group boundaries, @<str>@
-- stands for a string literal, and @r@ / @n@ followed by the list of binders
-- marks a recursive / non-recursive binding.
instance Outputable FloatLang where
  ppr float = case float of
    StartBindingGroup                 -> char '('
    EndBindingGroup                   -> char ')'
    PlainTopBinding StgTopStringLit{} -> text "<str>"
    PlainTopBinding (StgTopLifted b)  -> ppr (LiftedBinding b)
    LiftedBinding bind                ->
      let (rec_flag, pairs) = decomposeStgBinding bind
          marker
            | isRec rec_flag = char 'r'
            | otherwise      = char 'n'
      in marker <+> ppr (map fst pairs)

-- | Flattens an expression in @['FloatLang']@ into an STG program, see "GHC.Stg.Lift.Monad#floats".
-- Important pre-conditions: The nesting of opening 'StartBindingGroup's and
-- closing 'EndBindingGroup's is balanced. Also, it is crucial that every binding
-- group has at least one recursive binding inside. Otherwise there's no point
-- in announcing the binding group in the first place and an @ASSERT@ will
-- trigger.
--
-- 'LiftedBinding's outside of any group become top-level bindings of their
-- own. All 'LiftedBinding's inside an outermost group are merged into a single
-- 'StgRec' when that group is closed. 'removeRhsCCCS' is applied to the RHSs
-- of every lifted binding. Panics on unbalanced groups and on a
-- 'PlainTopBinding' inside a group.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    -- @go n binds floats@: @n@ is the current nesting depth of binding groups
    -- and @binds@ accumulates (most recent first) the bindings lifted inside
    -- the currently open outermost group.
    go 0 [] []     = []
    go _ _  []     = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n+1) binds rest
      EndBindingGroup
        | n == 0    -> pprPanic "collectFloats" (text "no group to end")
        -- Closing the outermost group: emit everything it collected as one
        -- recursive top-level binding and start afresh.
        | n == 1    -> StgTopLifted (merge_binds binds) : go 0 [] rest
        | otherwise -> go (n-1) binds rest
      PlainTopBinding top_bind
        | n == 0    -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | n == 0    -> StgTopLifted (rm_cccs bind) : go n binds rest
        | otherwise -> go n (bind:binds) rest

    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding
    -- Applies 'removeRhsCCCS' to every RHS of a binding.
    rm_cccs = map_rhss removeRhsCCCS
    merge_binds binds = ASSERT( any is_rec binds )
                        StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)
    is_rec StgRec{} = True
    is_rec _ = False

-- | Replaces a current cost-centre stack (as judged by 'isCurrentCCS') on a
-- closure or constructor RHS with 'dontCareCCS'; every other RHS is returned
-- unchanged.
--
-- Omitting this makes for strange closure allocation schemes that crash the
-- GC.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS rhs = case rhs of
  StgRhsClosure ext ccs upd bndrs body
    | isCurrentCCS ccs -> StgRhsClosure ext dontCareCCS upd bndrs body
  StgRhsCon ccs con mu ts args
    | isCurrentCCS ccs -> StgRhsCon dontCareCCS con mu ts args
  _ -> rhs

-- | The lambda-lifting monad, a newtype over an 'RWST' with these components:
--
--     * 'Env': Reader-like context. Contains a substitution, info about how
--       lifted identifiers are to be expanded into applications, and details
--       such as 'DynFlags'.
--
--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
--
--     * No pure state component (the state type is @()@).
--
--     * The base monad is 'UniqSM', used for generating fresh lifted binders.
--       (The @uniqAway@ approach could give the same name to two different
--       lifted binders, so this is necessary.)
newtype LiftM a = LiftM
  { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a
  }
  deriving (Functor, Applicative, Monad)

-- | The 'DynFlags' are taken from the 'e_dflags' field of the environment.
instance HasDynFlags LiftM where
  getDynFlags = LiftM (e_dflags <$> RWS.ask)

-- | Every unique operation is delegated to the underlying 'UniqSM'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM $ lift getUniqueSupplyM
  getUniqueM       = LiftM $ lift getUniqueM
  getUniquesM      = LiftM $ lift getUniquesM

-- | Runs a 'LiftM' action from 'emptyEnv' with the given unique supply and
-- flattens its 'FloatLang' output into top-level STG bindings via
-- 'collectFloats'.
runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us action =
  let initial_env    = emptyEnv dflags
      (_, _, output) = initUs_ us (runRWST (unwrapLiftM action) initial_env ())
  in collectFloats (fromOL output)

-- | Writes a plain 'StgTopStringLit' binding the given 'OutId' to the given
-- bytes to the output.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit lit_id bytes =
  LiftM $ RWS.tell $ unitOL $ PlainTopBinding $ StgTopStringLit lit_id bytes

-- | Starts a recursive binding group by emitting 'StartBindingGroup'.
-- See "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))

-- | Ends a recursive binding group by emitting 'EndBindingGroup'.
-- See "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))

-- | Lifts a binding to top-level by emitting it as a 'LiftedBinding'.
-- Depending on whether it's declared inside a recursive RHS (see
-- "GHC.Stg.Lift.Monad#floats" and 'collectFloats'), this might be added to an
-- existing recursive top-level binding group.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding bind = LiftM (RWS.tell (unitOL (LiftedBinding bind)))

-- | Takes a binder and a continuation, which is called with the substituted
-- binder (as computed by 'substBndr'). The continuation runs in a 'LiftM'
-- context whose substitution has been extended accordingly, so the binder is
-- deemed in scope there. Think of it as a 'RWS.local' computation: after the
-- continuation finishes, the new binding won't be in scope anymore.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  current <- RWS.asks e_subst
  let (bndr', extended) = substBndr bndr current
      enter env = env { e_subst = extended }
  RWS.local enter (unwrapLiftM (inner bndr'))

-- | 'withSubstBndr' for a whole traversable collection of binders. The
-- binders are brought into scope in traversal order, and the continuation
-- receives all the substituted binders.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs bndrs = runContT (traverse subst_one bndrs)
  where
    subst_one bndr = ContT (withSubstBndr bndr)

-- | Similarly to 'withSubstBndr', this function takes a set of variables to
-- abstract over, the binder to lift (and generate a fresh, substituted name
-- for) and a continuation in which that fresh, lifted binder is in scope.
--
-- The fresh binder is a system-local 'Id' named @$l<occ>@ with a fresh unique.
-- Its type abstracts over the given variables (via 'mkLamTypes'). Inside the
-- continuation, the original binder is substituted by the fresh one and is
-- recorded in 'e_expansions' as expanding to the abstracted variables.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let abs_vars    = dVarSetElems abs_ids
      lifted_name = mkFastString ("$l" ++ occNameString (getOccName bndr))
      lifted_ty   = mkLamTypes abs_vars (idType bndr)
      -- See Note [transferPolyIdInfo] in GHC.Types.Id. We need to do this at
      -- least for arity information.
      bndr'       = transferPolyIdInfo bndr abs_vars
                      (mkSysLocal lifted_name uniq Many lifted_ty)
      enter env   = env
        { e_subst      = extendSubst bndr bndr' (extendInScope bndr' (e_subst env))
        , e_expansions = extendVarEnv (e_expansions env) bndr abs_ids
        }
  LiftM (RWS.local enter (unwrapLiftM (inner bndr')))

-- | 'withLiftedBndr' for a whole traversable collection of binders, all of
-- which abstract over the same set of variables. The binders are lifted in
-- traversal order.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids bndrs = runContT (traverse lift_one bndrs)
  where
    lift_one bndr = ContT (withLiftedBndr abs_ids bndr)

-- | Substitutes a binder /occurrence/, which was brought in scope earlier by
-- 'withSubstBndr' \/ 'withLiftedBndr', using the current 'e_subst'.
substOcc :: Id -> LiftM Id
substOcc occ = LiftM (lookupIdSubst occ <$> RWS.asks e_subst)

-- | Whether the given binding was decided to be lambda lifted, i.e. whether
-- it has an entry in 'e_expansions'.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM (elemVarEnv bndr <$> RWS.asks e_expansions)

-- | For a lifted binding, the list of all local variables it abstracts over
-- (so, exactly the additional arguments at adjusted call sites). For a binding
-- that was not lifted, the empty list.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure (maybe [] dVarSetElems (lookupVarEnv expansions f))

-- | Creates an /expander function/ for the current set of lifted binders.
-- This expander function will replace any 'InId' by its corresponding 'OutId'
-- and, in addition, will expand any lifted binder into the former free
-- variables it abstracts over.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst      <- RWS.asks e_subst
  -- We use @noWarnLookupIdSubst@ here in order to suppress "not in scope"
  -- warnings generated by 'lookupIdSubst' due to local bindings within RHS.
  -- These are not in the InScopeSet of @subst@ and extending the InScopeSet in
  -- @goodToLift@/@closureGrowth@ before passing it on to @expander@ is too much
  -- trouble.
  let expand_one acc fv
        | Just former_fvs <- lookupVarEnv expansions fv
        = unionDVarSet acc former_fvs
        | otherwise -- Not lifted
        = extendDVarSet acc (noWarnLookupIdSubst fv subst)
      expander fvs = foldl' expand_one emptyDVarSet (dVarSetElems fvs)
  pure expander