{-# LANGUAGE CPP #-}

-----------------------------------------------------------------------------
--
-- Code generator utilities for checking pointer tags; mostly monadic
--
-- (c) The University of Glasgow 2004-2006
--
-----------------------------------------------------------------------------

module GHC.StgToCmm.TagCheck
  ( emitTagAssertion, emitArgTagCheck, checkArg, whenCheckTags,
    checkArgStatic, checkFunctionArgTags,checkConArgsStatic,checkConArgsDyn) where

#include "ClosureTypes.h"

import GHC.Prelude

import GHC.StgToCmm.Env
import GHC.StgToCmm.Monad
import GHC.StgToCmm.Utils
import GHC.Cmm
import GHC.Cmm.BlockId
import GHC.Cmm.Graph as CmmGraph

import GHC.Core.Type
import GHC.Types.Id
import GHC.Utils.Misc
import GHC.Utils.Outputable

import GHC.Core.DataCon
import Control.Monad
import GHC.StgToCmm.Types
import GHC.Utils.Panic (pprPanic)
import GHC.Utils.Panic.Plain (panic)
import GHC.Stg.Syntax
import GHC.StgToCmm.Closure
import GHC.Cmm.Switch (mkSwitchTargets)
import GHC.Cmm.Info (cmmGetClosureType)
import GHC.Types.RepType (dataConRuntimeRepStrictness)
import GHC.Types.Basic
import GHC.Data.FastString (mkFastString)

import qualified Data.Map as M

-- | Check all arguments marked as already tagged for a function
-- are tagged by inserting runtime checks.
--
-- Only arguments marked 'MarkedCbv' in the function's CBV marks are checked,
-- and of those only arguments whose type is lifted. A function without CBV
-- marks ('idCbvMarks_maybe' returns Nothing) gets no checks. Nothing at all
-- is emitted unless tag checking is enabled (see 'whenCheckTags').
checkFunctionArgTags :: SDoc -> Id -> [Id] -> FCode ()
checkFunctionArgTags msg f args = whenCheckTags $
  onJust (return ()) (idCbvMarks_maybe f) $ \marks -> do
    -- Only check args marked as strict, and only lifted ones.
    -- 'isLiftedRuntimeRep' classifies a RuntimeRep, not a type, so ask about
    -- the runtime rep of the argument's type. Applying it to the type itself
    -- yields False for ordinary types and silently disabled every check.
    let cbv_args = filter (isLiftedRuntimeRep . getRuntimeRep . idType)
                 $ filterByList (map isMarkedCbv marks) args
    -- Get their (cmm) address
    arg_infos <- mapM getCgIdInfo cbv_args
    let arg_cmms = map idInfoToAmode arg_infos
    mapM_ (emitTagAssertion (showPprUnsafe msg)) arg_cmms

-- | Check all required-tagged arguments of a constructor are tagged *at compile time*.
--
-- Each argument is paired with the constructor's field strictness (from
-- 'dataConRuntimeRepStrictness') and handed to 'checkArgStatic', which panics
-- on a strict field whose argument is not statically known to be tagged.
-- Does nothing unless tag checking is enabled.
checkConArgsStatic :: SDoc -> DataCon -> [StgArg] -> FCode ()
checkConArgsStatic msg con args =
  whenCheckTags (zipWithM_ (checkArgStatic msg) field_marks args)
  where
    field_marks = dataConRuntimeRepStrictness con

-- | Check all required arguments of a constructor are tagged, emitting
-- runtime checks for those not statically known to be tagged.
--
-- Field strictness marks are converted to CBV marks with 'cbvFromStrictMark'
-- and each argument is handed to 'checkArg'. Does nothing unless tag checking
-- is enabled.
checkConArgsDyn :: SDoc -> DataCon -> [StgArg] -> FCode ()
checkConArgsDyn msg con args = whenCheckTags $
  let cbv_marks = cbvFromStrictMark <$> dataConRuntimeRepStrictness con
  in  zipWithM_ (checkArg msg) cbv_marks args

-- | Run the given code-generation action only when tag checking is enabled
-- ('stgToCmmDoTagCheck' in the StgToCmm configuration); otherwise do nothing.
whenCheckTags :: FCode () -> FCode ()
whenCheckTags act = do
  cfg <- getStgToCmmConfig
  if stgToCmmDoTagCheck cfg
    then act
    else return ()

-- | Call barf if we failed to predict a tag correctly.
--
-- This is immensely useful when debugging issues in tag inference, as it
-- aborts the program as soon as an invalid call or heap object is
-- encountered, rather than letting it segfault arbitrarily later or produce
-- invalid results.
--
-- The emitted check passes if either
--
-- * the pointer @fun@ carries a tag, or
-- * the closure it points to has one of the closure types that
--   'needsArgTag' accepts untagged (PAP, BCO and the FUN variants).
--
-- Otherwise barf is called with a message ending in @onWhat@.
emitTagAssertion :: String -> CmmExpr -> FCode ()
emitTagAssertion onWhat fun = do
  platform <- getPlatform
  lret     <- newBlockId
  lno_tag  <- newBlockId
  lbarf    <- newBlockId
  -- Tagged pointer: nothing more to check, go straight to the end.
  emit (mkCbranch (cmmIsTagged platform fun) lret lno_tag (Just True))
  -- Untagged pointer: acceptable only for certain closure types.
  emitLabel lno_tag
  emitComment (mkFastString "closereTypeCheck")
  needsArgTag fun lbarf lret
  -- Failure path.
  emitLabel lbarf
  emitBarf ("Tag inference failed on:" ++ onWhat)
  emitLabel lret

-- | Jump to the first block if the argument closure is subject to tagging
-- requirements; otherwise jump to the second one.
--
-- The closure type of @closure@ is read and switched on. PAP, BCO and the
-- FUN closure types go to @lpass@; every other type in the range
-- INVALID_OBJECT .. N_CLOSURE_TYPES goes to @lfail@.
needsArgTag :: CmmExpr -> BlockId -> BlockId -> FCode ()
needsArgTag closure lfail lpass = do
  profile     <- getProfile
  align_check <- stgToCmmAlignCheck <$> getStgToCmmConfig
  let clo_ty_e = cmmGetClosureType profile align_check closure
      -- The ENTER macro doesn't evaluate FUN/PAP/BCO objects. So we
      -- have to accept them not being tagged. See #21193
      -- See Note [TagInfo of functions]
      untagged_ok :: [Integer]
      untagged_ok = [ PAP, BCO
                    , FUN, FUN_1_0, FUN_0_1, FUN_2_0, FUN_1_1, FUN_0_2
                    , FUN_STATIC
                    ]
      targets = mkSwitchTargets False
                                (INVALID_OBJECT, N_CLOSURE_TYPES)
                                (Just lfail)
                                (M.fromList [ (clo_ty, lpass) | clo_ty <- untagged_ok ])
  emit (mkSwitch clo_ty_e targets)
  emit (mkBranch lpass)


-- | Emit runtime tag checks for the arguments @args@, guided by their
-- call-by-value @marks@.
--
-- Only arguments marked 'MarkedCbv' whose type is lifted are checked. Each
-- check's failure message names the current module, @info@ and the offending
-- argument. Nothing is emitted unless tag checking is enabled.
emitArgTagCheck :: SDoc -> [CbvMark] -> [Id] -> FCode ()
emitArgTagCheck info marks args = whenCheckTags $ do
  this_mod <- getModuleName
  -- 'isLiftedRuntimeRep' classifies a RuntimeRep, so apply it to the runtime
  -- rep of each argument's type (not to the type itself, which always gives
  -- False for ordinary types and disabled every check).
  let cbv_args = filter (isLiftedRuntimeRep . getRuntimeRep . idType)
               $ filterByList (map isMarkedCbv marks) args
  arg_infos <- mapM getCgIdInfo cbv_args
  let arg_cmms = map idInfoToAmode arg_infos
      mk_msg arg = showPprUnsafe (text "Untagged arg:" <> ppr this_mod <> char ':' <> info <+> ppr arg)
  -- Build the messages from the same filtered list as the addresses, so that
  -- each assertion names the argument it actually checks.
  zipWithM_ emitTagAssertion (map mk_msg cbv_args) arg_cmms

-- | Is a binding with this code-generator info statically known to be
-- tagged, judging only by its lambda-form info?
--
-- Constructors, re-entrant closures and unlifted values count as tagged;
-- thunks and unknown bindings do not. A let-no-escape binding is not expected
-- here and causes a panic.
taggedCgInfo :: CgIdInfo -> Bool
taggedCgInfo = tagged_lf . cg_lf
  where
    tagged_lf lf = case lf of
      LFCon {}       -> True
      LFReEntrant {} -> True
      LFUnlifted {}  -> True
      LFThunk {}     -> False
      LFUnknown {}   -> False
      LFLetNoEscape  -> panic "Let no escape binding passed to top level con"

-- | Check that one argument is properly tagged.
--
-- Arguments marked 'NotMarkedCbv' and literal arguments are never checked.
-- For a 'MarkedCbv' variable not statically known to be tagged (see
-- 'taggedCgInfo'), a runtime assertion is emitted on its Cmm location.
-- Runs only when tag checking is enabled; panics on a let-no-escape location.
checkArg :: SDoc -> CbvMark -> StgArg -> FCode ()
checkArg msg mark arg = case mark of
  NotMarkedCbv -> return ()
  MarkedCbv    -> whenCheckTags (check_one arg)
  where
    check_one (StgLitArg _) = return ()
    check_one (StgVarArg v) = do
      info <- getCgIdInfo v
      unless (taggedCgInfo info) $
        case cg_loc info of
          CmmLoc loc -> emitTagAssertion (showPprUnsafe (msg <+> text "arg:" <> ppr arg)) loc
          LneLoc {}  -> panic "LNE-arg"

-- | Check, at compile time, that an argument is properly tagged.
--
-- Non-strict fields and literal arguments are accepted. For a strict field
-- whose variable argument is not statically known to be tagged (see
-- 'taggedCgInfo'), compilation panics. Runs only when tag checking is
-- enabled.
checkArgStatic :: SDoc -> StrictnessMark -> StgArg -> FCode ()
checkArgStatic _   NotMarkedStrict _   = return ()
checkArgStatic msg MarkedStrict    arg = whenCheckTags (verify arg)
  where
    verify (StgLitArg _) = return ()
    verify (StgVarArg v) = do
      is_tagged <- taggedCgInfo <$> getCgIdInfo v
      unless is_tagged $
        pprPanic "Arg not tagged as expectd" (ppr msg <+> ppr arg)