{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM'
-- monad.
module StgLiftLams.LiftM (
    decomposeStgBinding, mkStgBinding,
    Env (..),
    -- * #floats# Handling floats
    -- $floats
    FloatLang (..), collectFloats, -- Exported just for the docs
    -- * Transformation monad
    LiftM, runLiftM, withCaffyness,
    -- ** Adding bindings
    startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
    -- ** Substitution and binders
    withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
    -- ** Occurrences
    substOcc, isLifted, formerFreeVars, liftedIdsExpander
  ) where

#include "HsVersions.h"

import GhcPrelude

import BasicTypes
import CostCentre ( isCurrentCCS, dontCareCCS )
import DynFlags
import FastString
import Id
import IdInfo
import Name
import Outputable
import OrdList
import StgSubst
import StgSyn
import Type
import UniqSupply
import Util
import VarEnv
import VarSet

import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )

-- | Splits a binding into its recursiveness flag and its binder/RHS pairs.
-- A non-recursive binding always yields exactly one pair.
--
-- @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding (StgRec pairs)       = (Recursive, pairs)
decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])

-- | Inverse of 'decomposeStgBinding': rebuilds a binding from its
-- recursiveness flag and binder/RHS pairs.
--
-- A 'NonRecursive' binding must come with exactly one pair. Anything else is
-- a caller bug and panics with a descriptive message. Previously, a partial
-- 'head' crashed on an empty list and silently dropped surplus pairs.
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive    pairs         = StgRec pairs
mkStgBinding NonRecursive [(bndr, rhs)] = StgNonRec bndr rhs
mkStgBinding NonRecursive pairs
  = pprPanic "mkStgBinding"
      (text "non-recursive binding needs exactly one pair, got"
         <+> int (length pairs))

-- | Environment threaded around in a scoped, @Reader@-like fashion.
data Env
  = Env
  { e_dflags     :: !DynFlags
  -- ^ Read-only.
  , e_subst      :: !Subst
  -- ^ We need to track the renamings of local 'InId's to their lifted 'OutId',
  -- because shadowing might make a closure's free variables unavailable at its
  -- call sites. Consider:
  -- @
  --    let f y = x + y in let x = 4 in f x
  -- @
  -- Here, @f@ can't be lifted to top-level, because its free variable @x@ isn't
  -- available at its call site.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ Lifted 'Id's don't occur as free variables in any closure anymore, because
  -- they are bound at the top-level. Every occurrence must supply the formerly
  -- free variables of the lifted 'Id', so they in turn become free variables of
  -- the call sites. This environment tracks this expansion from lifted 'Id's to
  -- their free variables.
  --
  -- Maps 'InId's to (sets of) 'OutId's.
  --
  -- Invariant: 'Id's not present in this map won't be substituted.
  , e_in_caffy_context :: !Bool
  -- ^ Are we currently analysing within a caffy context (e.g. the containing
  -- top-level binder's 'idCafInfo' is 'MayHaveCafRefs')? If not, we can safely
  -- assume that functions we lift out aren't caffy either.
  }

-- | The initial environment: the given 'DynFlags', an empty substitution, no
-- lifted-'Id' expansions, and not inside a caffy context.
--
-- Built with record syntax so that it stays correct if 'Env's fields are
-- ever reordered.
emptyEnv :: DynFlags -> Env
emptyEnv dflags = Env
  { e_dflags           = dflags
  , e_subst            = emptySubst
  , e_expansions       = emptyVarEnv
  , e_in_caffy_context = False
  }


-- Note [Handling floats]
-- ~~~~~~~~~~~~~~~~~~~~~~
-- $floats
-- Consider the following expression:
--
-- @
--     f x =
--       let g y = ... f y ...
--       in g x
-- @
--
-- What happens when we want to lift @g@? Normally, we'd put the lifted @l_g@
-- binding above the binding for @f@:
--
-- @
--     g f y = ... f y ...
--     f x = g f x
-- @
--
-- But this needlessly turns a known call to @f@ into an unknown one, and it
-- complicates matters for the analysis.
-- Instead, we'd really like to put both functions in the same recursive group,
-- thereby preserving the known call:
--
-- @
--     Rec {
--       g y = ... f y ...
--       f x = g x
--     }
-- @
--
-- But we don't want this to happen for just /any/ binding. That would create
-- possibly huge recursive groups in the process, calling for an occurrence
-- analyser on STG.
-- So, we need to track when we lift a binding out of a recursive RHS and add
-- the binding to the same recursive group as the enclosing recursive binding
-- (which must have either already been at the top-level or decided to be
-- lifted itself in order to preserve the known call).
--
-- This is done by expressing this kind of nesting structure as a 'Writer' over
-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to
-- 'collectFloats'.
-- API-wise, the analysis will not need to know about the whole 'FloatLang'
-- business and will just manipulate it indirectly through actions in 'LiftM'.

-- | We need to detect when we are lifting something out of the RHS of a
-- recursive binding (cf. "StgLiftLams.LiftM#floats"), in which case that
-- binding needs to be added to the same top-level recursive group. This
-- requires us to detect a certain nesting structure, which is encoded by
-- 'StartBindingGroup' and 'EndBindingGroup'.
--
-- Note that 'collectFloats' only cares whether the current binding to be
-- lifted (through 'LiftedBinding') occurs inside such a binding group or not,
-- i.e. it doesn't care about the nesting level as long as it's greater than 0.
data FloatLang
  = StartBindingGroup
    -- ^ Opens a (possibly nested) binding group.
  | EndBindingGroup
    -- ^ Closes the innermost open binding group.
  | PlainTopBinding OutStgTopBinding
    -- ^ A top-level binding emitted as is. Only allowed outside of any
    -- binding group.
  | LiftedBinding OutStgBinding
    -- ^ A lifted binding. Outside of a group it becomes its own top-level
    -- binding; inside a group it is merged into the group's recursive binding.

-- | Compact debug rendering: groups print as parentheses, string literals as
-- @<str>@, and bindings as @r@ (recursive) or @n@ (non-recursive) followed by
-- their binders.
instance Outputable FloatLang where
  ppr StartBindingGroup                   = char '('
  ppr EndBindingGroup                     = char ')'
  ppr (PlainTopBinding StgTopStringLit{}) = text "<str>"
  ppr (PlainTopBinding (StgTopLifted b))  = ppr (LiftedBinding b)
  ppr (LiftedBinding bind)
    = (if isRec rec_flag then char 'r' else char 'n') <+> ppr (map fst pairs)
    where
      (rec_flag, pairs) = decomposeStgBinding bind

-- | Flattens an expression in @['FloatLang']@ into an STG program, see #floats.
-- Important pre-conditions: The nesting of opening 'StartBindingGroup's and
-- closing 'EndBindingGroup's is balanced. Also, it is crucial that every binding
-- group has at least one recursive binding inside. Otherwise there's no point
-- in announcing the binding group in the first place and an @ASSERT@ will
-- trigger.
--
-- A binding lifted outside of any group becomes its own top-level binding.
-- All bindings lifted inside an outermost group (including its nested groups)
-- are merged into a single top-level 'StgRec', emitted when that outermost
-- group closes. Lifted RHSs get their current cost centre stack replaced via
-- 'removeRhsCCCS'. Panics on unbalanced groups and on a 'PlainTopBinding'
-- inside a group.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    -- @go n binds floats@: @n@ is the current group nesting depth and @binds@
    -- accumulates (most recent first) the bindings lifted inside the current
    -- outermost group.
    go 0 [] [] = []
    go _ _ [] = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n+1) binds rest
      EndBindingGroup
        | n == 0 -> pprPanic "collectFloats" (text "no group to end")
        -- Closing the outermost group: emit everything collected so far.
        | n == 1 -> StgTopLifted (merge_binds binds) : go 0 [] rest
        | otherwise -> go (n-1) binds rest
      PlainTopBinding top_bind
        | n == 0 -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | n == 0 -> StgTopLifted (rm_cccs bind) : go n binds rest
        | otherwise -> go n (bind:binds) rest

    -- Applies a function to every RHS of a binding, keeping its binders and
    -- its recursiveness.
    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding
    rm_cccs = map_rhss removeRhsCCCS
    -- Merges all bindings of a group into one recursive binding.
    merge_binds binds = ASSERT( any is_rec binds )
                        StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)
    is_rec StgRec{} = True
    is_rec _ = False

-- | Replaces the cost centre stack of a closure or constructor RHS by
-- 'dontCareCCS' when 'isCurrentCCS' holds for it. Any other RHS is returned
-- unchanged.
--
-- Omitting this makes for strange closure allocation schemes that crash the
-- GC.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body)
  | isCurrentCCS ccs
  = StgRhsClosure ext dontCareCCS upd bndrs body
removeRhsCCCS (StgRhsCon ccs con args)
  | isCurrentCCS ccs
  = StgRhsCon dontCareCCS con args
removeRhsCCCS rhs = rhs

-- | The analysis monad consists of the following 'RWST' components:
--
--     * 'Env': Reader-like context. Contains a substitution, info about how
--       lifted identifiers are to be expanded into applications and details
--       such as 'DynFlags' and a flag that helps determine whether a lifted
--       binding is caffy.
--
--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
--
--     * No pure state component
--
--     * But wrapping around 'UniqSM' for generating fresh lifted binders.
--       (The @uniqAway@ approach could give the same name to two different
--       lifted binders, so this is necessary.)
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)

-- | The 'DynFlags' are read from the (read-only) 'e_dflags' field of the
-- environment.
instance HasDynFlags LiftM where
  getDynFlags = LiftM (RWS.asks e_dflags)

-- | Uniques are drawn from the underlying 'UniqSM' by lifting through the
-- 'RWST' layer.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
  getUniqueM       = LiftM (lift getUniqueM)
  getUniquesM      = LiftM (lift getUniquesM)

-- | Runs a 'LiftM' action in the environment @'emptyEnv' dflags@ with the
-- given unique supply, then flattens the emitted 'FloatLang' stream into
-- top-level bindings via 'collectFloats'. The action's result and the
-- (trivial) state are discarded.
runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us (LiftM m) = collectFloats (fromOL floats)
  where
    (_, _, floats) = initUs_ us (runRWST m (emptyEnv dflags) ())

-- | Assumes a given caffyness for the execution of the passed action, which
-- influences the 'cafInfo' of lifted bindings. The flag only applies for the
-- duration of the action (it is set via 'RWS.local').
withCaffyness :: Bool -> LiftM a -> LiftM a
withCaffyness caffy action
  = LiftM (RWS.local (\e -> e { e_in_caffy_context = caffy }) (unwrapLiftM action))

-- | Writes a plain 'StgTopStringLit' to the output.
--
-- Must not be emitted inside a binding group: 'collectFloats' panics on a
-- 'PlainTopBinding' there.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit bndr
  = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit bndr

-- | Starts a recursive binding group. See #floats# and 'collectFloats'.
-- Must be balanced by a later 'endBindingGroup'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))

-- | Ends a recursive binding group. See #floats# and 'collectFloats'.
-- Must close a group previously opened by 'startBindingGroup'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))

-- | Lifts a binding to top-level. Depending on whether it's declared inside
-- a recursive RHS (see #floats# and 'collectFloats'), this might be added to
-- an existing recursive top-level binding group.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding

-- | Takes a binder and a continuation which is called with the substituted
-- binder. The continuation will be evaluated in a 'LiftM' context in which that
-- binder is deemed in scope. Think of it as a 'RWS.local' computation: After
-- the continuation finishes, the new binding won't be in scope anymore.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  subst <- RWS.asks e_subst
  -- 'substBndr' may rename the binder (e.g. to avoid capture) and returns the
  -- extended substitution to use for the continuation.
  let (bndr', subst') = substBndr bndr subst
  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))

-- | Like 'withSubstBndr', but for a whole 'Traversable' container of
-- binders. The binders are brought into scope one after another in
-- 'traverse' order, each within the scope of the previous ones. The
-- continuation receives the substituted binders in the same shape.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs = runContT . traverse (ContT . withSubstBndr)

-- | The lambda-lifting counterpart of 'withSubstBndr'.
--
-- Given the set of variables to abstract over (@abs_ids@), the binder to
-- lift (@bndr@) and a continuation (@inner@), this creates a fresh lifted
-- binder and runs @inner@ with it in scope.
--
-- The fresh binder has these properties:
--
--   * It is named @$l\<occ\>@ after @bndr@.
--   * Its type is @idType bndr@ abstracted over @abs_ids@ (via 'mkLamTypes').
--   * Its 'CafInfo' is derived from the enclosing context.
--   * It receives @bndr@'s poly 'IdInfo' via 'transferPolyIdInfo'.
--
-- Within @inner@, occurrences of @bndr@ are substituted by the fresh binder,
-- and @bndr@ is recorded as expanding to @abs_ids@.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  -- The lifted binding can only be caffy if the enclosing top-level binding
  -- is. If we don't recognise this, non-caffy things end up calling caffy
  -- things and codegen gets confused.
  caffy <- LiftM (RWS.asks e_in_caffy_context)
  let abs_list :: [Id]
      abs_list = dVarSetElems abs_ids

      lifted_name :: FastString
      lifted_name = mkFastString ("$l" ++ occNameString (getOccName bndr))

      lifted_ty :: Type
      lifted_ty = mkLamTypes abs_list (idType bndr)

      caf_info :: CafInfo
      caf_info
        | caffy     = MayHaveCafRefs
        | otherwise = NoCafRefs

      -- Set the CafInfo explicitly. Otherwise the new binder is assumed to
      -- be caffy and to need an SRT. Transitive call sites that are not
      -- caffy themselves would then miss a static link field in their
      -- closures.
      fresh :: Id
      fresh = setIdCafInfo (mkSysLocalOrCoVar lifted_name uniq lifted_ty) caf_info

      -- See Note [transferPolyIdInfo] in Id.hs. It is needed at least for
      -- arity information.
      bndr' :: Id
      bndr' = transferPolyIdInfo bndr abs_list fresh

      bring_into_scope :: Env -> Env
      bring_into_scope env = env
        { e_subst      = extendSubst bndr bndr' (extendInScope bndr' (e_subst env))
        , e_expansions = extendVarEnv (e_expansions env) bndr abs_ids
        }
  LiftM (RWS.local bring_into_scope (unwrapLiftM (inner bndr')))

-- | Lifts every binder in a 'Traversable' container, as 'withLiftedBndr' does
-- for a single binder.
--
-- All binders abstract over the same @abs_ids@. The continuation receives
-- the fresh binders in a container of the same shape. Each binder's scope
-- encloses the ones after it in traversal order.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids bndrs inner = runContT (traverse liftOne bndrs) inner
  where
    -- View the CPS-style 'withLiftedBndr' as a 'ContT' action so it can be
    -- traversed.
    liftOne bndr = ContT (withLiftedBndr abs_ids bndr)

-- | Substitutes a binder /occurrence/ using the current substitution.
--
-- The binder must have been brought into scope earlier by
-- 'withSubstBndr'\/'withLiftedBndr'.
substOcc :: Id -> LiftM Id
substOcc occ = LiftM $ RWS.asks $ \env -> lookupIdSubst occ (e_subst env)

-- | Whether the given binding was decided to be lambda lifted.
--
-- A binding counts as lifted exactly when it has an entry in the expansions
-- environment, which 'withLiftedBndr' extends.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure (bndr `elemVarEnv` expansions)

-- | The local variables a lifted binding abstracts over.
--
-- These are exactly the additional arguments needed at adjusted call sites.
-- For a binding that was not lifted, the result is the empty list.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM (RWS.asks abstractedOver)
  where
    abstractedOver :: Env -> [OutId]
    abstractedOver env = maybe [] dVarSetElems (lookupVarEnv (e_expansions env) f)

-- | Creates an /expander function/ for the current set of lifted binders.
--
-- The expander maps each 'InId' in its argument set as follows:
--
--   * A binder that was not lifted is replaced by its corresponding 'OutId'.
--   * A lifted binder is replaced by the former free variables it abstracts
--     over.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM (RWS.asks mkExpander)
  where
    mkExpander :: Env -> DIdSet -> DIdSet
    mkExpander env fvs = foldl' step emptyDVarSet (dVarSetElems fvs)
      where
        -- We use @noWarnLookupIdSubst@ here to suppress "not in scope"
        -- warnings that 'lookupIdSubst' would generate for local bindings
        -- within the RHS. Those bindings are not in the InScopeSet of the
        -- substitution. Extending the InScopeSet in
        -- @goodToLift@/@closureGrowth@ before passing it on to the expander
        -- is too much trouble.
        step :: DIdSet -> Id -> DIdSet
        step acc fv = case lookupVarEnv (e_expansions env) fv of
          Just former_fvs -> acc `unionDVarSet` former_fvs
          Nothing         -> acc `extendDVarSet` noWarnLookupIdSubst fv (e_subst env) -- Not lifted