{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}

{-# OPTIONS_GHC -Wno-unused-top-binds #-}

module IfSat.Plugin.Compat
  ( wrapTcS, getRestoreTcS )
  where

-- base

import Unsafe.Coerce
  ( unsafeCoerce )

-- ghc

#if MIN_VERSION_ghc(9,4,0)
import GHC.Tc.Solver.InertSet
  ( WorkList, InertSet )
#endif
import GHC.Tc.Solver.Monad
  ( TcS
#if MIN_VERSION_ghc(9,1,0)
  , TcLevel, wrapTcS
#endif
#if !MIN_VERSION_ghc(9,4,0)
  , WorkList, InertSet
#endif
  )
import GHC.Tc.Types
  ( TcM, TcRef )
import GHC.Tc.Types.Evidence
  ( EvBindsVar(..) )

-- ghc-tcplugin-api

import GHC.TcPlugin.API
  ( readTcRef, writeTcRef )

--------------------------------------------------------------------------------


-- | Capture the current 'TcS' state, returning an action which restores
-- the fields of 'TcSEnv' as appropriate after running a test-run
-- of 'solveSimpleWanteds' and deciding to backtrack.
--
-- The snapshot covers the contents of the mutable references reachable from
-- the evidence bindings variable ('ebv_binds' and 'ebv_tcvs', or only
-- 'ebv_tcvs' for a 'CoEvBindsVar'), together with the references stored in
-- the @shim_tcs_unified@, @shim_tcs_unif_lvl@ (GHC 9.1 and later) and
-- @shim_tcs_count@ fields of the environment.
--
-- Running the returned action writes the captured values back into those
-- same references; it can be run any number of times.
getRestoreTcS :: TcS (TcS ())
getRestoreTcS = do
  shim_tcs_env <- getShimTcSEnv
  let ev_binds_var   = shim_tcs_ev_binds shim_tcs_env
      unif_var       = shim_tcs_unified  shim_tcs_env
#if MIN_VERSION_ghc(9,1,0)
      unif_lvl_var   = shim_tcs_unif_lvl shim_tcs_env
#endif
      unit_count_var = shim_tcs_count    shim_tcs_env
  wrapTcS $ do
    -- Snapshot the evidence bindings; which references exist depends on
    -- the constructor of the 'EvBindsVar'.
    restore_evBinds <- case ev_binds_var of
      EvBindsVar { ebv_binds = ev_binds_ref
                 , ebv_tcvs  = ev_cvs_ref } ->
        do ev_binds <- readTcRef ev_binds_ref
           ev_cvs   <- readTcRef ev_cvs_ref
           return do
             writeTcRef ev_binds_ref ev_binds
             writeTcRef ev_cvs_ref   ev_cvs
      CoEvBindsVar { ebv_tcvs = ev_cvs_ref } ->
        do ev_cvs   <- readTcRef ev_cvs_ref
           return do
             writeTcRef ev_cvs_ref   ev_cvs

    -- Snapshot the remaining mutable fields.
    unif         <- readTcRef unif_var
#if MIN_VERSION_ghc(9,1,0)
    unif_lvl     <- readTcRef unif_lvl_var
#endif
    count        <- readTcRef unit_count_var

    -- The restore action: write every captured value back.
    return $ wrapTcS $ do
      restore_evBinds
      writeTcRef unif_var       unif
#if MIN_VERSION_ghc(9,1,0)
      writeTcRef unif_lvl_var   unif_lvl
#endif
      writeTcRef unit_count_var count

  -- NB: no need to reset 'tcs_inerts' or 'tcs_worklist', because
  -- 'solveSimpleWanteds' calls 'nestTcS', which appropriately resets
  -- both of those fields.


#if !MIN_VERSION_ghc(9,1,0)
-- Compatibility shim: lift a 'TcM' action into 'TcS' on GHC versions older
-- than 9.1, for which this module does not import 'wrapTcS' from
-- "GHC.Tc.Solver.Monad".
--
-- NOTE(review): the coercion presumably relies on 'TcS a' being
-- representationally a function from the solver environment to 'TcM a',
-- so that a function ignoring its argument is the lifted action -- confirm
-- against the GHC sources for every GHC version this branch is built with.
wrapTcS :: TcM a -> TcS a
wrapTcS tcm = unsafeCoerce ( const tcm )
#endif

-- Obtain the 'TcSEnv' underlying the 'TcS' monad (in the form of a 'ShimTcSEnv').
--
-- The environment type is not exported by GHC, so the function
-- @return :: ShimTcSEnv -> TcM ShimTcSEnv@ is reinterpreted as a 'TcS' action
-- with 'unsafeCoerce'.
--
-- NOTE(review): this is only sound if 'TcS a' is representationally a
-- function from the environment to 'TcM a', and if 'ShimTcSEnv' matches the
-- layout of the real 'TcSEnv' for the GHC version being compiled against;
-- a mismatch can segfault. Confirm both whenever a new GHC is supported.
getShimTcSEnv :: TcS ShimTcSEnv
getShimTcSEnv = unsafeCoerce ( return :: ShimTcSEnv -> TcM ShimTcSEnv )

-- | A shim copy of "GHC.Tc.Solver.Monad.TcSEnv", to work around the
-- fact that it isn't exported.
--
-- Needs to be manually kept in sync with 'TcSEnv' to avoid segfaults due
-- to the use of 'unsafeCoerce' in 'getShimTcSEnv'.
--
-- Because values of this type are obtained by coercion rather than by
-- construction, the number and order of the fields (including the
-- CPP-guarded ones) must not be changed independently of GHC's own
-- definition. Only the first few fields are read by 'getRestoreTcS'; the
-- remaining ones are declared so that the record has the same shape.
data ShimTcSEnv
  = ShimTcSEnv
  { shim_tcs_ev_binds           :: EvBindsVar
  , shim_tcs_unified            :: TcRef Int
#if MIN_VERSION_ghc(9,1,0)
  , shim_tcs_unif_lvl           :: TcRef (Maybe TcLevel)
#endif
  , shim_tcs_count              :: TcRef Int
  , shim_tcs_inerts             :: TcRef InertSet
#if MIN_VERSION_ghc(9,3,0)
  , shim_tcs_abort_on_insoluble :: Bool
#endif
  , shim_tcs_worklist           :: TcRef WorkList
  }