{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE CPP #-}
module Graph.Trace
  ( plugin
  , module DT
  , module Trace
  ) where

import           Control.Monad (when)
import           Control.Monad.IO.Class (liftIO)
import           Control.Monad.Trans.State.Strict
import           Control.Monad.Trans.Writer.CPS
import qualified Data.Generics as Syb
import qualified Data.Set as S

import           Graph.Trace.Internal.Predicates (addConstraintToSig, removeConstraints)
import qualified Graph.Trace.Internal.GhcFacade as Ghc
import           Graph.Trace.Internal.Instrument (modifyClsInstDecl, modifyTyClDecl, modifyValBinds)
import           Graph.Trace.Internal.Solver (tcPlugin)
import           Graph.Trace.Internal.Types as DT
import           Graph.Trace.Internal.Trace as Trace

-- | The GHC plugin entry point.
--
-- Extends @Ghc.defaultPlugin@ with two hooks: the renamed-source hook
-- ('renamedResultAction' in this module) and the type-checker plugin
-- (@tcPlugin@), whose command-line options are ignored. The plugin is
-- declared pure (@Ghc.purePlugin@) for recompilation checking.
plugin :: Ghc.Plugin
plugin = Ghc.defaultPlugin
  { Ghc.renamedResultAction = renamedResultAction
  , Ghc.tcPlugin            = const (Just tcPlugin)
  , Ghc.pluginRecompile     = Ghc.purePlugin
  }

-- | Resolve a module by name with the GHC finder
-- (@Ghc.findImportedModule@), using no package qualifier.
--
-- Calls @error@ (message prefixed with @unable to find module: @) unless
-- the finder reports @Ghc.Found@.
findImportedModule :: String -> Ghc.TcM Ghc.Module
findImportedModule moduleName = do
  hscEnv <- Ghc.getTopEnv
  let wanted = Ghc.mkModuleName moduleName
  findResult <- liftIO $
#if MIN_VERSION_ghc(9,4,0)
    Ghc.findImportedModule hscEnv wanted Ghc.NoPkgQual
#else
    Ghc.findImportedModule hscEnv wanted Nothing
#endif
  case findResult of
    Ghc.Found _ located -> pure located
    _ -> error ("unable to find module: " <> moduleName)

-- | Print a notice on stdout for each enabled optimization that the
-- graph-trace plugin recommends turning off.
--
-- Checks full laziness first, then common sub-expression elimination.
-- All checks read one snapshot of the general flags from the current
-- DynFlags. This only prints; compilation is not affected.
warnAboutOptimizations :: Ghc.TcM ()
warnAboutOptimizations = do
  enabledFlags <- Ghc.generalFlags <$> Ghc.getDynFlags
  mapM_ (noticeIfEnabled enabledFlags) interferingOptimizations
  where
    -- Emit the notice paired with a flag only when that flag is set.
    noticeIfEnabled enabledFlags (flag, notice) =
      when (Ghc.enumSetMember flag enabledFlags) $
        liftIO (putStrLn notice)

    -- Flags to check, in the order their notices are printed.
    interferingOptimizations =
      [ ( Ghc.Opt_FullLaziness
        , " * Full laziness is enabled: it's generally recommended to disable this optimization when using graph-trace. Use the -fno-full-laziness GHC option to disable it."
        )
      , ( Ghc.Opt_CSE
        , " * Common sub-expression elimination is enabled: it's generally recommended to disable this optimization when using graph-trace. Use the -fno-cse GHC option to disable it."
        )
      ]

-- | Whether the @MonomorphismRestriction@ language extension is enabled
-- in the current DynFlags.
isMonomorphismRestrictionOn :: Ghc.TcM Bool
isMonomorphismRestrictionOn = do
  dynFlags <- Ghc.getDynFlags
  pure (Ghc.xopt Ghc.MonomorphismRestriction dynFlags)

-- | Renamed-source hook that instruments a module for tracing.
--
-- Only a group whose value bindings are already in renamed form
-- (@Ghc.XValBindsLR@) is processed; any other group is returned
-- unchanged. Processing runs in passes:
--
--   1. @addConstraintToSig@ over the whole group, yielding a name map
--      (with the @trace-all@ option, every function signature gets the
--      debug predicate);
--   2. @modifyValBinds@ over the value bindings;
--   3. @modifyClsInstDecl@, @modifyTyClDecl@ and @modifyValBinds@ over the
--      type/class declaration groups;
--   4. when the monomorphism restriction is on, @removeConstraints@ for the
--      pattern-bound names reported by passes 2 and 3, because otherwise
--      compilation fails.
--
-- Before any of this, a notice about interfering optimizations may be
-- printed (see 'warnAboutOptimizations').
renamedResultAction
  :: [Ghc.CommandLineOption]
  -> Ghc.TcGblEnv
  -> Ghc.HsGroup Ghc.GhcRn
  -> Ghc.TcM (Ghc.TcGblEnv, Ghc.HsGroup Ghc.GhcRn)
renamedResultAction
    cmdLineOptions
    tcGblEnv
    hsGroup@Ghc.HsGroup{Ghc.hs_valds = Ghc.XValBindsLR{}} = do
    warnAboutOptimizations
    debugNames <- lookupDebugNames

    -- Pass 1: add the debug predicate to signatures. The "trace-all"
    -- option applies it to all function signatures.
    let traceAll = "trace-all" `elem` cmdLineOptions
        (signedGroup, nameMap) =
          runWriter $
            Syb.everywhereM
              (Syb.mkM (addConstraintToSig debugNames traceAll))
              hsGroup

    -- Pass 2: value bindings.
    (valBinds', patBindNames) <-
      runTraversal $
        Syb.everywhereM
          (Syb.mkM (modifyValBinds debugNames nameMap))
          (Ghc.hs_valds signedGroup)

    -- Pass 3: type class declarations and instances.
    -- TODO Only need to traverse with modifyValBinds. Others are not applied deeply
    (tyClGroups', tyClPatBindNames) <-
      runTraversal $
        Syb.everywhereM
          ( Syb.mkM (modifyClsInstDecl debugNames nameMap)
              `Syb.extM` modifyTyClDecl debugNames nameMap
              `Syb.extM` modifyValBinds debugNames nameMap
          )
          (Ghc.hs_tyclds signedGroup)

    -- Pass 4: with the monomorphism restriction on, signatures of
    -- pattern-bound ids must lose the added predicates, otherwise
    -- compilation fails.
    mmrOn <- isMonomorphismRestrictionOn
    let (finalValBinds, finalTyClGroups)
          | mmrOn =
              ( removeConstraints debugNames patBindNames valBinds'
              , removeConstraints debugNames tyClPatBindNames tyClGroups'
              )
          | otherwise = (valBinds', tyClGroups')

    pure
      ( tcGblEnv
      , signedGroup
          { Ghc.hs_valds = finalValBinds
          , Ghc.hs_tyclds = finalTyClGroups
          }
      )
  where
    -- Resolve the plugin support modules, then look up every name the
    -- instrumentation needs (order of lookups is kept fixed).
    lookupDebugNames :: Ghc.TcM DebugNames
    lookupDebugNames = do
      typesModule <- findImportedModule "Graph.Trace.Internal.Types"
      traceModule <- findImportedModule "Graph.Trace.Internal.Trace"
      let lookupClass = Ghc.lookupOrig typesModule . Ghc.mkClsOcc
      muteCls     <- lookupClass "TraceMute"
      deepCls     <- lookupClass "TraceDeep"
      deepKeyCls  <- lookupClass "TraceDeepKey"
      traceCls    <- lookupClass "Trace"
      traceKeyCls <- lookupClass "TraceKey"
      inertCls    <- lookupClass "TraceInert"
      entryVar    <- Ghc.lookupOrig traceModule (Ghc.mkVarOcc "entry")
      contextTc   <- Ghc.lookupOrig typesModule (Ghc.mkTcOcc "DebugContext")
      pure DebugNames
        { traceMutePredName    = muteCls
        , traceDeepPredName    = deepCls
        , traceDeepKeyPredName = deepKeyCls
        , tracePredName        = traceCls
        , traceKeyPredName     = traceKeyCls
        , traceInertPredName   = inertCls
        , entryName            = entryVar
        , debugContextName     = contextTc
        }

    -- Run a writer-over-state traversal starting from an empty set, keeping
    -- the result and the accumulated writer output.
    runTraversal
      :: (Monad m, Monoid w)
      => WriterT w (StateT (S.Set e) m) r
      -> m (r, w)
    runTraversal traversal = evalStateT (runWriterT traversal) S.empty

renamedResultAction _ env untouchedGroup = pure (env, untouchedGroup)