{-|
  Prepare the STG for bytecode generation:

   - Ensure that all breakpoints are directly under
        a let-binding, introducing a new binding for
        those that aren't already.

   - Protect not-necessarily-lifted join points (NNLJPs), see
        Note [Not-necessarily-lifted join points]

 -}

module GHC.Stg.BcPrep ( bcPrep ) where

import GHC.Prelude

import GHC.Types.Id.Make
import GHC.Types.Id
import GHC.Core.Type
import GHC.Builtin.Types ( unboxedUnitTy )
import GHC.Builtin.Types.Prim
import GHC.Types.Unique
import GHC.Data.FastString
import GHC.Utils.Panic.Plain
import GHC.Types.Tickish
import GHC.Types.Unique.Supply
import qualified GHC.Types.CostCentre as CC
import GHC.Stg.Syntax
import GHC.Utils.Monad.State.Strict

-- | State threaded through the bytecode-preparation pass.
data BcPrepM_State
   = BcPrepM_State
        { prepUniqSupply :: !UniqSupply      -- ^ for generating fresh variable names
        }

-- | The strict state monad (GHC.Utils.Monad.State.Strict) the pass runs in.
type BcPrepM a = State BcPrepM_State a

-- | Prepare the right-hand side of a binding.
--
-- The body of a closure is prepared with 'bcPrepExpr'. There is one
-- exception: when the body is itself a breakpoint tick, the tick already
-- sits directly under a binding. In that case only the expression under the
-- tick is prepared, and no extra let-binding is introduced for the
-- breakpoint. Constructor applications are returned unchanged.
bcPrepRHS :: StgRhs -> BcPrepM StgRhs
-- explicitly match all constructors so we get a warning if we miss any
bcPrepRHS (StgRhsClosure fvs cc upd args body) =
  StgRhsClosure fvs cc upd args <$> prepBody body
  where
    prepBody :: StgExpr -> BcPrepM StgExpr
    -- A breakpoint directly under an StgRhsClosure needs no new binding:
    -- keep the tick where it is and only prepare what it wraps.
    prepBody (StgTick bp@Breakpoint{} expr) = StgTick bp <$> bcPrepExpr expr
    prepBody expr                           = bcPrepExpr expr
bcPrepRHS con@StgRhsCon{} = pure con

-- | Prepare an STG expression for bytecode generation.
--
--   * A breakpoint tick that is not already the body of a closure is bound
--     by a fresh let (binder named "bcprep"), and the tick site becomes a
--     use of that binder.
--   * A nullary occurrence of a not-necessarily-lifted join point becomes an
--     application of the protected binder (see 'protectNNLJoinPointId') to
--     'voidPrimId'.
--   * All other forms are traversed structurally.
bcPrepExpr :: StgExpr -> BcPrepM StgExpr
-- explicitly match all constructors so we get a warning if we miss any
bcPrepExpr (StgTick bp@(Breakpoint tick_ty _ _) rhs)
  | isLiftedTypeKind (typeKind tick_ty)
    -- Lifted result: bind a nullary closure and refer to it directly,
    --   let b = \r [] <tick> rhs' in b
  = bindBreakpoint tick_ty [] []
  | otherwise
    -- Result kind not known to be lifted: bind a function over a void
    -- argument and apply it to realWorld#,
    --   let b = \r [void] <tick> rhs' in b realWorld#
    -- (presumably because such a value cannot be bound as a thunk).
    -- NOTE(review): the lambda binder is voidArgId, while the binder's type
    -- uses State# RealWorld; both are zero-width, but confirm this is intended.
  = bindBreakpoint (mkVisFunTyMany realWorldStatePrimTy tick_ty)
                   [voidArgId]
                   [StgVarArg realWorldPrimId]
  where
    -- Build @let b = <closure over lam_args wrapping the tick> in b app_args@.
    -- The fresh binder is allocated before recursing into rhs, so uniques are
    -- drawn in a fixed order: the binder first, then whatever rhs needs.
    bindBreakpoint :: Type -> [Id] -> [StgArg] -> BcPrepM StgExpr
    bindBreakpoint bndr_ty lam_args app_args = do
      bndr <- newId bndr_ty
      rhs' <- bcPrepExpr rhs
      let closure = StgRhsClosure noExtFieldSilent
                                  CC.dontCareCCS
                                  ReEntrant
                                  lam_args
                                  (StgTick bp rhs')
      pure (StgLet noExtFieldSilent (StgNonRec bndr closure) (StgApp bndr app_args))
bcPrepExpr (StgTick tick rhs) =
  StgTick tick <$> bcPrepExpr rhs
bcPrepExpr (StgLet xlet bnds expr) =
  StgLet xlet <$> bcPrepBind bnds
              <*> bcPrepExpr expr
-- A let-no-escape is rebuilt as an ordinary StgLet. This fits the byte-code
-- generator treating join points like ordinary variables
-- (see Note [Not-necessarily-lifted join points]).
bcPrepExpr (StgLetNoEscape xlne bnds expr) =
  StgLet xlne <$> bcPrepBind bnds
              <*> bcPrepExpr expr
bcPrepExpr (StgCase expr bndr alt_type alts) =
  StgCase <$> bcPrepExpr expr
          <*> pure bndr
          <*> pure alt_type
          <*> mapM bcPrepAlt alts
bcPrepExpr lit@StgLit{} = pure lit
-- See Note [Not-necessarily-lifted join points], step 3.
bcPrepExpr (StgApp x [])
  | isNNLJoinPoint x = pure $
      StgApp (protectNNLJoinPointId x) [StgVarArg voidPrimId]
bcPrepExpr app@StgApp{}    = pure app
bcPrepExpr app@StgConApp{} = pure app
bcPrepExpr app@StgOpApp{}  = pure app

-- | Prepare the right-hand side of a case alternative. The constructor and
-- its binders are left unchanged.
bcPrepAlt :: StgAlt -> BcPrepM StgAlt
bcPrepAlt (GenStgAlt con bndrs rhs) = GenStgAlt con bndrs <$> bcPrepExpr rhs

-- | Prepare a (possibly recursive) binding group. Each binder/RHS pair is
-- processed by 'bcPrepPair'.
bcPrepBind :: StgBinding -> BcPrepM StgBinding
-- explicitly match all constructors so we get a warning if we miss any
bcPrepBind (StgNonRec bndr rhs) = uncurry StgNonRec <$> bcPrepPair (bndr, rhs)
bcPrepBind (StgRec pairs)       = StgRec <$> mapM bcPrepPair pairs

-- | Process one binder/RHS pair. The pair first goes through
-- 'bcPrepSingleBind', which protects a not-necessarily-lifted join point.
-- The (possibly updated) RHS is then prepared with 'bcPrepRHS'.
bcPrepPair :: (Id, StgRhs) -> BcPrepM (Id, StgRhs)
bcPrepPair pair =
  let (bndr', rhs') = bcPrepSingleBind pair
  in  (,) bndr' <$> bcPrepRHS rhs'

-- | If necessary, modify this Id and body to protect a not-necessarily-lifted
-- join point. The binder's type gets a @(# #) ->@ prefix (via
-- 'protectNNLJoinPointId'), and the closure gains a trailing 'voidArgId'
-- lambda binder. Any other binding is returned unchanged.
-- See Note [Not-necessarily-lifted join points], step 2.
bcPrepSingleBind :: (Id, StgRhs) -> (Id, StgRhs)
bcPrepSingleBind (x, StgRhsClosure ext cc upd_flag args body)
  | isNNLJoinPoint x
  = ( protectNNLJoinPointId x
    , StgRhsClosure ext cc upd_flag (args ++ [voidArgId]) body )
bcPrepSingleBind bnd = bnd

-- | Prepare one top-level binding. Top-level string literals need no work.
-- Lifted bindings go through 'bcPrepBind'.
bcPrepTopLvl :: StgTopBinding -> BcPrepM StgTopBinding
bcPrepTopLvl lit@StgTopStringLit{} = pure lit
bcPrepTopLvl (StgTopLifted bnd)    = StgTopLifted <$> bcPrepBind bnd

-- | Entry point of the pass. It prepares a list of top-level STG bindings for
-- bytecode generation. The given 'UniqSupply' provides uniques for any fresh
-- binders the pass introduces.
bcPrep :: UniqSupply -> [InStgTopBinding] -> [OutStgTopBinding]
bcPrep us bnds = evalState (mapM bcPrepTopLvl bnds) (BcPrepM_State us)

-- | Is this Id a not-necessarily-lifted join point (NNLJP), i.e. a join
-- point whose type might be unlifted?
-- See Note [Not-necessarily-lifted join points], step 1.
isNNLJoinPoint :: Id -> Bool
isNNLJoinPoint x = isJoinId x && mightBeUnliftedType (idType x)

-- | Update an Id's type to take a @(# #)@ argument, turning @t@ into
-- @(# #) -> t@.
-- Precondition (asserted): the Id is a not-necessarily-lifted join point.
-- See Note [Not-necessarily-lifted join points].
protectNNLJoinPointId :: Id -> Id
protectNNLJoinPointId x
  = assert (isNNLJoinPoint x) $
    updateIdTypeButNotMult (unboxedUnitTy `mkVisFunTyMany`) x

-- | Take a fresh 'Unique' from the supply in the pass state and store the
-- remaining supply back.
newUnique :: BcPrepM Unique
newUnique = state $ \st ->
  case takeUniqFromSupply (prepUniqSupply st) of
    (uniq, us) -> (uniq, st { prepUniqSupply = us })

-- | Make a fresh system-local Id of the given type. It is named with
-- 'prepFS' and has multiplicity 'Many'.
newId :: Type -> BcPrepM Id
newId ty = do
  uniq <- newUnique
  pure (mkSysLocal prepFS uniq Many ty)

-- | Base name used for every binder this pass creates (see 'newId').
prepFS :: FastString
prepFS = fsLit "bcprep"

{-

Note [Not-necessarily-lifted join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A join point variable is essentially a goto-label: it is, for example,
never used as an argument to another function, and it is called only
in tail position. See Note [Join points] and Note [Invariants on join points],
both in GHC.Core. Because join points do not compile to true, red-blooded
variables (with, e.g., registers allocated to them), they are allowed
to be representation-polymorphic.
(See invariant #6 in Note [Invariants on join points] in GHC.Core.)

However, in this byte-code generator, join points *are* treated just as
ordinary variables. There is no check whether a binding is for a join point
or not; they are all treated uniformly. (Perhaps there is a missed optimization
opportunity here, but that is beyond the scope of my (Richard E's) Thursday.)

We thus must have *some* strategy for dealing with representation-polymorphic
and unlifted join points. Representation-polymorphic variables are generally
not allowed (though representation-polymorphic join points *are*; see
Note [Invariants on join points] in GHC.Core, point 6), and we don't wish to
evaluate unlifted join points eagerly.
The questionable join points are *not-necessarily-lifted join points*
(NNLJPs). (Not having such a strategy led to #16509, which panicked in the
isUnliftedType check in the AnnVar case of schemeE.) Here is the strategy:

1. Detect NNLJPs. This is done in isNNLJoinPoint.

2. When binding an NNLJP, add a `\ (_ :: (# #)) ->` to its RHS, and modify the
   type to tack on a `(# #) ->`.
   Note that functions are never representation-polymorphic, so this
   transformation changes an NNLJP to a non-representation-polymorphic
   join point. This is done in bcPrepSingleBind.

3. At an occurrence of an NNLJP, add an application to the (# #) value
   voidPrimId, being careful to note the new type of the NNLJP. This is done
   in the StgApp case of bcPrepExpr, with help from protectNNLJoinPointId.

Here is an example. Suppose we have

  f = \(r :: RuntimeRep) (a :: TYPE r) (x :: T).
      join j :: a
           j = error @r @a "bloop"
      in case x of
           A -> j
           B -> j
           C -> error @r @a "blurp"

Our plan is to behave as if the code was

  f = \(r :: RuntimeRep) (a :: TYPE r) (x :: T).
      let j :: ((# #) -> a)
          j = \ _ -> error @r @a "bloop"
      in case x of
           A -> j (# #)
           B -> j (# #)
           C -> error @r @a "blurp"

It's a bit hacky, but it works well in practice and is local. I suspect the
Right Fix is to take advantage of join points as goto-labels.

-}