{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TupleSections #-}

module IfSat.Plugin
  ( plugin )
  where

-- base

import Control.Monad
  ( filterM )
import Data.Foldable
  ( for_ )
import Data.Maybe
  ( catMaybes, mapMaybe )

-- ghc

import GHC.Plugins
  hiding ( TcPlugin, (<>) )
import GHC.Data.Bag
  ( unitBag )
#if MIN_VERSION_ghc(9,7,0)
import GHC.Tc.Solver.Solve
  ( solveSimpleGivens, solveSimpleWanteds )
#else
import GHC.Tc.Solver.Interact
  ( solveSimpleGivens, solveSimpleWanteds )
#endif
import GHC.Tc.Solver.Monad
  ( runTcSWithEvBinds, traceTcS )
import GHC.Tc.Types
  ( TcM )
import GHC.Tc.Types.Constraint
  ( isEmptyWC, CtEvidence (..), ctEvEvId )
import GHC.Tc.Utils.TcType
  ( MetaDetails(..), metaTyVarRef
  , tyCoVarsOfTypeList
  )
import GHC.Tc.Utils.TcMType
  ( isUnfilledMetaTyVar, newTcEvBinds )

-- ghc-tcplugin-api

import GHC.TcPlugin.API
import GHC.TcPlugin.API.Internal
  ( unsafeLiftTcM )

-- if-instance

import IfSat.Plugin.Compat
  ( wrapTcS, getRestoreTcS )

--------------------------------------------------------------------------------
-- Plugin definition.


-- | A type-checking plugin that solves @ct_l || ct_r@ constraints.
-- This allows users to branch on whether @ct_l@ is satisfied.
--
-- To use this plugin, add @{-# OPTIONS_GHC -fplugin=IfSat.Plugin #-}@
-- to your module header.
--
-- A @ct_l || ct_r@  instance is solved by trying to solve @ct_l@:
--
--   - if solving succeeds, the 'Data.Constraint.If.dispatch' function will
--     pick the first branch,
--   - otherwise, 'Data.Constraint.If.dispatch' will pick the second branch.
--
-- This means that the branch selection occurs precisely at the moment
-- at which we solve the @ct_l || ct_r@  constraint.
-- See the documentation of 'Data.Constraint.If.dispatch' for more information.
plugin :: Plugin
plugin =
  defaultPlugin
    { -- Plugin command-line arguments are ignored.
      tcPlugin        = \ _args -> Just $ mkTcPlugin ifSatTcPlugin
    , pluginRecompile = purePlugin
    }

-- | The type-checking plugin itself, assembled from its four stages:
-- initialisation ('initPlugin'), constraint solving ('solver'),
-- type-family rewriting ('rewriter') and a stop action that does nothing.
ifSatTcPlugin :: TcPlugin
ifSatTcPlugin =
  TcPlugin
    { tcPluginInit    = initPlugin
    , tcPluginSolve   = solver
    , tcPluginRewrite = rewriter
    , tcPluginStop    = \ _ -> pure ()
    }

--------------------------------------------------------------------------------
-- Plugin initialisation.


-- | Names looked up once by 'initPlugin' and shared by the solver and
-- the rewriter.
data PluginDefs
  = PluginDefs
    { orClass    :: !Class
      -- ^ The class named @||@ in "Data.Constraint.If".
    , isSatTyCon :: !TyCon
      -- ^ The type constructor named @IsSat@ in "Data.Constraint.If".
    }

-- | Find the imported module with the given name.
--
-- The import is resolved without an explicit package qualifier.
-- Calls 'error' if the lookup finds several modules with that name,
-- or does not find the module at all.
findModule :: MonadTcPlugin m => String -> m Module
findModule modName = do
  let mod_name = mkModuleName modName
  pkg_qual   <- resolveImport      mod_name Nothing
  findResult <- findImportedModule mod_name pkg_qual
  case findResult of
    Found _ res     -> pure res
    FoundMultiple _ -> error $ "IfSat plugin: found multiple modules named " <> modName <> "."
    _               -> error $ "IfSat plugin: could not find any module named " <> modName <> "."

-- | Initialise the plugin: locate the module "Data.Constraint.If" and look up
-- the @||@ class and the @IsSat@ type constructor defined in it.
initPlugin :: TcPluginM Init PluginDefs
initPlugin = do
  ifSatModule <- findModule "Data.Constraint.If"
  orClass     <- tcLookupClass =<< lookupOrig ifSatModule ( mkClsOcc "||"    )
  isSatTyCon  <- tcLookupTyCon =<< lookupOrig ifSatModule ( mkTcOcc  "IsSat" )
  pure $ PluginDefs { orClass, isSatTyCon }

--------------------------------------------------------------------------------
-- Constraint solving.


-- | Solver stage of the plugin.
--
-- Runs 'solveWanted' on each Wanted constraint in turn and reports every
-- constraint for which it produced evidence as solved. No new constraints
-- are emitted. With no Wanteds, there is nothing to do.
solver :: PluginDefs -> [ Ct ] -> [ Ct ] -> TcPluginM Solve TcPluginSolveResult
solver defs givens wanteds
  | null wanteds
  = pure $ TcPluginOk [] []
  | otherwise
  = do
      tcPluginTrace "IfSat solver {" (ppr givens $$ ppr wanteds)
      solveds <- catMaybes <$> traverse ( solveWanted defs givens ) wanteds
      tcPluginTrace "IfSat solver }" empty
      pure $ TcPluginOk solveds []

-- | Try to solve a single Wanted constraint of the form @ct_l || ct_r@.
--
-- Returns 'Nothing' when the constraint is not an application of the @||@
-- class to exactly two arguments, or when neither @ct_l@ nor @ct_r@ could be
-- solved. Otherwise returns the evidence term paired with the original
-- constraint.
--
-- @ct_l@ is attempted first. Only if that fails is the solver state rolled
-- back and @ct_r@ attempted.
solveWanted :: PluginDefs -> [ Ct ] -> Ct -> TcPluginM Solve ( Maybe ( EvTerm, Ct ) )
solveWanted defs@( PluginDefs { orClass } ) givens wanted
  | ClassPred cls [ct_l_ty, ct_r_ty] <- classifyPredType ( ctPred wanted )
  , cls == orClass
  = do
    tcPluginTrace "IfSat solver: found (||) constraint"
      ( ppr ct_l_ty $$ ppr ct_r_ty $$ ppr wanted )
    ct_l_ev <- newWanted ( ctLoc wanted ) ct_l_ty
    ct_r_ev <- newWanted ( ctLoc wanted ) ct_r_ty
    let
      ct_l, ct_r :: Ct
      ct_l = mkNonCanonical ct_l_ev
      ct_r = mkNonCanonical ct_r_ev
      ct_l_ev_dest, ct_r_ev_dest :: TcEvDest
      ct_l_ev_dest = wantedEvDest ct_l_ev
      ct_r_ev_dest = wantedEvDest ct_r_ev

    evBindsVar <- askEvBinds
    -- Start a new constraint solver run.
    unsafeLiftTcM $ runTcSWithEvBinds evBindsVar $ do
      -- Add back all the Givens.
      traceTcS "IfSat solver: adding Givens to the inert set" (ppr givens)
      solveSimpleGivens givens

      -- Keep track of the current solver state in order to backtrack
      -- in the event that our attempt at solving 'ct_l' fails.
      ct_l_unfilled_metas <- wrapTcS
                           $ filterM isUnfilledMetaTyVar
                           $ tyCoVarsOfTypeList ct_l_ty
      restoreTcS <- getRestoreTcS

      -- Try to solve 'ct_l', using both Givens and top-level instances.
      residual_ct_l <- solveSimpleWanteds ( unitBag ct_l )

      -- Now look up whether GHC has managed to produce evidence for 'ct_l'.
      mb_ct_l_evTerm <- lookupEvTerm evBindsVar ct_l_ev_dest
      mb_wanted_evTerm <- case mb_ct_l_evTerm of
        Just ( EvExpr ct_l_evExpr )
          | isEmptyWC residual_ct_l
          -> do
            -- We've managed to solve 'ct_l': use the evidence and take the 'True' branch.
            traceTcS "IfSat solver: LHS constraint could be solved"
              ( vcat
                [ text "ct_l =" <+> ppr ct_l_ty
                , text "ev   =" <+> ppr ct_l_evExpr
                ]
              )
            wrapTcS $ ( Just <$> dispatchTrueEvTerm defs givens ct_l_ty ct_r_ty ct_l_evExpr )
        _ -> do
          -- We couldn't solve 'ct_l': this means we must solve 'ct_r',
          -- to provide evidence needed for the 'False' branch.
          traceTcS "IfSat solver: LHS constraint could not be solved" $
            vcat [ text "ct_l =" <+> ppr ct_l_ty
                 , text "residual_ct_l =" <+> ppr residual_ct_l ]

          -- Reset the solver state to before we attempted to solve 'ct_l',
          -- and undo any type variable unifications that happened.
          restoreTcS
          wrapTcS $ for_ ct_l_unfilled_metas \ meta ->
            writeTcRef ( metaTyVarRef meta ) Flexi
          ct_r_unfilled_metas <- wrapTcS
                               $ filterM isUnfilledMetaTyVar
                               $ tyCoVarsOfTypeList ct_r_ty

          -- Try to solve 'ct_r', using both Givens and top-level instances.
          residual_ct_r <- solveSimpleWanteds ( unitBag ct_r )
          mb_ct_r_evTerm <- lookupEvTerm evBindsVar ct_r_ev_dest
          case mb_ct_r_evTerm of
            Just ( EvExpr ct_r_evExpr )
              | isEmptyWC residual_ct_r
              -> do
                -- We've managed to solve 'ct_r': use the evidence and take the 'False' branch.
                traceTcS "IfSat solver: RHS constraint could be solved" $
                  vcat [ text "ct_r =" <+> ppr ct_r_ty
                       , text "ev   =" <+> ppr ct_r_evExpr
                       ]
                wrapTcS $ ( Just <$> dispatchFalseEvTerm defs givens ct_l_ty ct_r_ty ct_r_evExpr )
            _ -> do
              -- We could solve neither 'ct_l' nor 'ct_r'.
              -- This means we can't solve the disjunction constraint.
              traceTcS "IfSat solver: RHS constraint could not be solved" $
                vcat [ text "ct_r =" <+> ppr ct_r_ty
                     , text "residual_ct_r =" <+> ppr residual_ct_r ]

              -- Reset the solver state to before we attempted to solve 'ct_r',
              -- and undo any type variable unifications that happened.
              restoreTcS
              wrapTcS $ for_ ct_r_unfilled_metas \ meta ->
                writeTcRef ( metaTyVarRef meta ) Flexi

              pure Nothing
      pure $ ( , wanted ) <$> mb_wanted_evTerm
  | otherwise
  = pure Nothing

-- | Look up whether a 'TcEvDest' has been filled with evidence.
--
-- For a coercion hole, this reads the hole's mutable reference; the
-- 'EvBindsVar' argument is not used in that case.
-- For an evidence variable, this looks the variable up in the evidence
-- bindings of the given 'EvBindsVar' and returns the right-hand side of
-- its binding.
--
-- Returns 'Nothing' if the destination has not been filled.
lookupEvTerm :: EvBindsVar -> TcEvDest -> TcS ( Maybe EvTerm )
lookupEvTerm _ ( HoleDest ( CoercionHole { ch_ref = ref } ) ) = do
  mb_co <- wrapTcS $ readTcRef ref
  traceTcS "IfSat solver: coercion hole" ( ppr mb_co )
  pure ( evCoercion <$> mb_co )
lookupEvTerm evBindsVar ( EvVarDest ev_var ) = do
  evBindsMap <- getTcEvBindsMap evBindsVar
  let
    mb_evBind :: Maybe EvBind
    mb_evBind = lookupEvBind evBindsMap ev_var
  traceTcS "IfSat solver: evidence binding" ( ppr mb_evBind )
  pure ( eb_rhs <$> mb_evBind )

-- Evidence term for @ct_l || ct_r@ when @ct_l@ is satisfied.
--
-- dispatch =
--   \ @r
--     ( a :: ( IsSat ct_l ~ True, ct_l ) => r )
--     ( _ :: ( IsSat ct_l ~ False, IsSat ct_r ~ True, ct_r ) => r )
--   -> a ct_l_isSat_co ct_l_evTerm
--
-- The result is the class data constructor of @||@ applied to the two
-- constraint types and to the lambda above. The coercion passed to @a@
-- records the Given equalities whose evidence variables occur free in the
-- evidence for @ct_l@ (see 'usedGivenCoercions').
dispatchTrueEvTerm :: PluginDefs -> [ Ct ] -> Type -> Type -> EvExpr -> TcM EvTerm
dispatchTrueEvTerm defs@( PluginDefs { orClass } ) givens ct_l_ty ct_r_ty ct_l_evTerm = do
  r_name <- newName ( mkTyVarOcc "r" )
  a_name <- newName ( mkVarOcc   "a" )
  let
    r, a, b :: CoreBndr
    r = mkTyVar r_name liftedTypeKind
    a = mkLocalId a_name ManyTy
        ( mkInvisFunTys [ sat_eqTy defs ct_l_ty True, ct_l_ty ] r_ty )
    -- The second continuation is never used in this branch: wildcard binder.
    b = mkWildValBinder  ManyTy
        ( mkInvisFunTys [ sat_eqTy defs ct_l_ty False, sat_eqTy defs ct_r_ty True, ct_r_ty ] r_ty )
    r_ty :: Type
    r_ty = mkTyVarTy r
  pure . EvExpr $
    mkCoreConApps ( classDataCon orClass )
      [ Type ct_l_ty
      , Type ct_r_ty
      , mkCoreLams [ r, a, b ]
        ( mkCoreApps ( Var a )
          [ sat_co_expr defs ct_l_ty ( usedGivenCoercions givens ct_l_evTerm ) True
          , ct_l_evTerm
          ]
        )
      ]

-- Evidence term for @ct_l || ct_r@ when @ct_l@ isn't satisfied, but @ct_r@ is.
--
-- dispatch =
--   \ @r
--     ( _ :: ( IsSat ct_l ~ True, ct_l ) => r )
--     ( b :: ( IsSat ct_l ~ False, IsSat ct_r ~ True, ct_r ) => r )
--   -> b ct_l_notSat_co ct_r_isSat_co ct_r_evTerm
--
-- The result is the class data constructor of @||@ applied to the two
-- constraint types and to the lambda above.
dispatchFalseEvTerm :: PluginDefs -> [Ct] -> Type -> Type -> EvExpr -> TcM EvTerm
dispatchFalseEvTerm defs@( PluginDefs { orClass } ) givens ct_l_ty ct_r_ty ct_r_evExpr = do
  r_name <- newName ( mkTyVarOcc "r" )
  b_name <- newName ( mkVarOcc   "b" )
  let
    r, a, b :: CoreBndr
    r = mkTyVar r_name liftedTypeKind
    -- The first continuation is never used in this branch: wildcard binder.
    a = mkWildValBinder  ManyTy
        ( mkInvisFunTys [ sat_eqTy defs ct_l_ty True, ct_l_ty ] r_ty )
    b = mkLocalId b_name ManyTy
        ( mkInvisFunTys [ sat_eqTy defs ct_l_ty False, sat_eqTy defs ct_r_ty True, ct_r_ty ] r_ty )
    r_ty :: Type
    r_ty = mkTyVarTy r
  pure . EvExpr $
    mkCoreConApps ( classDataCon orClass )
      [ Type ct_l_ty
      , Type ct_r_ty
      , mkCoreLams [ r, a, b ]
        ( mkCoreApps ( Var b )
          [ sat_co_expr defs ct_l_ty [] False
            --                       ^^
            -- NB: GHC has no notion of apartness constraints, so there is
            -- no evidence we can provide for why we failed to solve a constraint.
          , sat_co_expr defs ct_r_ty ( usedGivenCoercions givens ct_r_evExpr ) True
          , ct_r_evExpr
          ]
        )
      ]

-- The type @IsSat ct ~ b@.
--
-- Built as the equality type constructor applied to the kind @Bool@,
-- to @IsSat ct@, and to the promoted boolean @b@.
sat_eqTy :: PluginDefs -> Type -> Bool -> Type
sat_eqTy ( PluginDefs { isSatTyCon } ) ct_ty booly
  = mkTyConApp eqTyCon
      [ boolTy, mkTyConApp isSatTyCon [ct_ty], rhs ]
  where
    rhs :: Type
    rhs = if booly then tru else fls

-- Construct an expression of type @IsSat ct ~ b@.
--
-- The evidence is a nominal plugin-provided universal coercion between
-- @IsSat ct@ and the promoted boolean @b@, boxed with the equality data
-- constructor. The @deps@ argument lists the coercions this evidence
-- depends on.
sat_co_expr :: PluginDefs -> Type -> [Coercion] -> Bool -> EvExpr
sat_co_expr ( PluginDefs { isSatTyCon } ) ct_ty deps booly
  = mkCoreConApps eqDataCon
      [ Type boolTy
      , Type $ mkTyConApp isSatTyCon [ ct_ty ]
      , Type rhs
      , Coercion $
          mkPluginUnivCo ( "IfSat :" <> show booly )
            Nominal deps
            ( mkTyConApp isSatTyCon [ct_ty] ) rhs
      ]
  where
    rhs :: Type
    rhs = if booly then tru else fls

-- The promoted data constructors 'False and 'True, as types.
fls, tru :: Type
fls = mkTyConTy promotedFalseDataCon
tru = mkTyConTy promotedTrueDataCon

-- | After filling in evidence for a constraint, compute which Givens the
-- evidence depends on.
--
-- Only Given equality constraints are considered: one is kept when its
-- evidence variable occurs among the free coercion variables of the
-- evidence expression, and it is returned as a coercion-variable coercion.
usedGivenCoercions :: [ Ct ] -> EvExpr -> [ Coercion ]
usedGivenCoercions givens ev = mapMaybe dep_cv givens
  where
    dep_cv :: Ct -> Maybe Coercion
    dep_cv ct
      | ctEv@( CtGiven {} ) <- ctEvidence ct
      , EqPred {} <- classifyPredType ( ctPred ct )
      , let v = ctEvEvId ctEv
      , v `elemVarSet` ev_cvs
      = Just $ mkCoVarCo v
      | otherwise
      = Nothing
    -- Free coercion variables of the evidence expression.
    ev_cvs :: CoVarSet
    ev_cvs = filterVarSet isCoVar $ exprFreeVars ev

-- | Small utility wrapper around 'ctev_dest' to avoid incomplete record
-- selector warnings.
--
-- Returns the evidence destination of a Wanted constraint, and panics
-- (via 'pprPanic', printing the offending evidence) when given a Given.
wantedEvDest :: HasDebugCallStack => CtEvidence -> TcEvDest
wantedEvDest ( CtWanted { ctev_dest = dst } ) = dst
wantedEvDest g@( CtGiven {} ) =
  pprPanic "wantedEvDest called on CtGiven" (ppr g)

--------------------------------------------------------------------------------


-- | The type family rewriters provided by this plugin: a single entry
-- mapping the @IsSat@ type constructor to 'isSatRewriter'.
rewriter :: PluginDefs -> UniqFM TyCon TcPluginRewriter
rewriter defs@( PluginDefs { isSatTyCon } )
  = listToUFM [ ( isSatTyCon, isSatRewriter defs ) ]

-- | Rewrite a type family application @IsSat ct@.
--
-- When applied to exactly one type argument @ct@, this emits a fresh Wanted
-- for @ct@ and tries to solve it in a new solver run seeded with the Givens.
-- The application is rewritten to the promoted @True@ when evidence for the
-- Wanted was found and no residual constraints remain, and to the promoted
-- @False@ otherwise. In the @True@ case the reduction records the Given
-- coercions that the evidence uses (see 'usedGivenCoercions').
--
-- The solver state is restored, and the metavariables of @ct@ that were
-- unfilled before the attempt are reset to 'Flexi', before the result is
-- computed.
--
-- For any other number of arguments, no rewriting takes place.
isSatRewriter :: PluginDefs -> [ Ct ] -> [ Type ] -> TcPluginM Rewrite TcPluginRewriteResult
isSatRewriter ( PluginDefs { isSatTyCon } ) givens [ct_ty] = do
  tcPluginTrace "IfSat rewriter {" (ppr givens $$ ppr ct_ty)
  rewriteEnv <- askRewriteEnv
  ct_ev <- newWanted ( rewriteEnvCtLoc rewriteEnv ) ct_ty
  let
    ct :: Ct
    ct = mkNonCanonical ct_ev
  evBindsVar <- unsafeLiftTcM newTcEvBinds

  -- Start a new Solver run.
  redn <- unsafeLiftTcM $ runTcSWithEvBinds evBindsVar $ do

    -- Add back all the Givens.
    traceTcS "IfSat rewriter: adding Givens to the inert set" (ppr givens)
    solveSimpleGivens givens

    -- Keep track of the current solver state in order to undo any
    -- side-effects after calling 'solveSimpleWanteds' on 'ct'.
    ct_unfilled_metas <- wrapTcS
                       $ filterM isUnfilledMetaTyVar
                       $ tyCoVarsOfTypeList ct_ty
    restoreTcS <- getRestoreTcS

    -- Try to solve 'ct', using both Givens and top-level instances.
    residual_wc <- solveSimpleWanteds ( unitBag ct )

    -- Look up the evidence before the solver state is reset below.
    mb_ct_evTerm <- lookupEvTerm evBindsVar $ wantedEvDest ct_ev

    -- Reset the solver state to before we attempted to solve 'ct',
    -- and undo any type variable unifications that happened.
    restoreTcS
    wrapTcS $ for_ ct_unfilled_metas \ meta ->
      writeTcRef ( metaTyVarRef meta ) Flexi

    let
      is_sat :: Bool
      sat :: Type
      deps :: [ Coercion ]
      ( is_sat, sat, deps )
        -- Satisfied: evidence was produced and nothing is left over.
        | Just ( EvExpr ct_evExpr ) <- mb_ct_evTerm
        , isEmptyWC residual_wc
        = ( True, tru, usedGivenCoercions givens ct_evExpr )
        | otherwise
        = ( False, fls, [] )
    pure $
      mkTyFamAppReduction ( "IsSat: " <> show is_sat )
        Nominal
        deps
        isSatTyCon
        [ct_ty]
        sat

  tcPluginTrace "IfSat rewriter }" ( ppr redn )
  pure $ TcPluginRewriteTo redn []

isSatRewriter _ _ _ = pure TcPluginNoRewrite