{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TupleSections #-}

module IfSat.Plugin
  ( plugin )
  where

-- base

import Data.Maybe
  ( catMaybes )
#if !MIN_VERSION_ghc(9,2,0)
import Unsafe.Coerce
  ( unsafeCoerce )
#endif

-- ghc

import GHC.Plugins
  hiding ( TcPlugin, (<>) )
import GHC.Data.Bag
  ( unitBag )
import GHC.Tc.Solver.Interact
  ( solveSimpleGivens, solveSimpleWanteds )
import GHC.Tc.Solver.Monad
  ( TcS
  , getTcEvBindsMap, readTcRef, runTcS, runTcSWithEvBinds, traceTcS
#if MIN_VERSION_ghc(9,2,0)
  , wrapTcS
#endif
  )
import GHC.Tc.Types
  ( TcM )
import GHC.Tc.Types.Constraint
  ( isEmptyWC )

-- ghc-tcplugin-api

import GHC.TcPlugin.API
import GHC.TcPlugin.API.Internal
  ( unsafeLiftTcM )

--------------------------------------------------------------------------------

-- Plugin definition.


-- | A type-checking plugin that solves @ct_l || ct_r@ constraints.
-- This allows users to branch on whether @ct_l@ is satisfied.
--
-- To use this plugin, add @{-# OPTIONS_GHC -fplugin=IfSat.Plugin #-}@
-- to your module header.
--
-- A @ct_l || ct_r@ instance is solved by trying to solve @ct_l@:
--
--   - if solving succeeds, the 'Data.Constraint.If.dispatch' function will
--     pick the first branch,
--   - otherwise, 'Data.Constraint.If.dispatch' will pick the second branch.
--
-- This means that the branch selection occurs precisely at the moment
-- at which we solve the @ct_l || ct_r@ constraint.
-- See the documentation of 'Data.Constraint.If.dispatch' for more information.
plugin :: Plugin
plugin = defaultPlugin
  { -- Command-line plugin arguments are accepted but ignored.
    tcPlugin        = const ( Just ( mkTcPlugin ifSatTcPlugin ) )
    -- The plugin is declared pure (see 'purePlugin').
  , pluginRecompile = purePlugin
  }

-- | The type-checker plugin proper.
--
-- 'initPlugin' looks up the definitions from "Data.Constraint.If" once.
-- 'rewriter' then handles the @IsSat@ type constructor, and 'solver'
-- handles @(||)@ constraints. Nothing needs to be done when the plugin stops.
ifSatTcPlugin :: TcPlugin
ifSatTcPlugin =
  TcPlugin
    { tcPluginInit    = initPlugin
    , tcPluginRewrite = rewriter
    , tcPluginSolve   = solver
    , tcPluginStop    = const ( pure () )
    }

--------------------------------------------------------------------------------

-- Plugin initialisation.


-- | The definitions from "Data.Constraint.If" that the plugin works with.
-- They are looked up once, in 'initPlugin', and then passed to the solver
-- and to the rewriter.
data PluginDefs =
  PluginDefs
    { orClass    :: !Class -- ^ The @(||)@ class, whose constraints 'solver' handles.
    , isSatTyCon :: !TyCon -- ^ The @IsSat@ type constructor, handled by 'rewriter'.
    }

-- | Find the module with the given name using 'findImportedModule'
-- (without a package qualifier).
--
-- Calls 'error' if no module, or more than one module, has that name.
findModule :: MonadTcPlugin m => String -> m Module
findModule modName = do
  result <- findImportedModule ( mkModuleName modName ) NoPkgQual
  case result of
    Found _ foundModule -> pure foundModule
    FoundMultiple {}    -> failWith "IfSat plugin: found multiple modules named "
    _                   -> failWith "IfSat plugin: could not find any module named "
  where
    -- Abort with the given message prefix, followed by the module name.
    failWith :: String -> a
    failWith prefix = error ( prefix <> modName <> "." )

-- | Plugin initialisation: look up the @(||)@ class and the @IsSat@
-- type constructor in "Data.Constraint.If".
initPlugin :: TcPluginM Init PluginDefs
initPlugin = do
  ifSatModule <- findModule "Data.Constraint.If"
  let
    -- Resolve an occurrence name relative to "Data.Constraint.If".
    lookupName :: OccName -> TcPluginM Init Name
    lookupName occ = lookupOrig ifSatModule occ
  orCls   <- lookupName ( mkClsOcc "||"    ) >>= tcLookupClass
  isSatTc <- lookupName ( mkTcOcc  "IsSat" ) >>= tcLookupTyCon
  pure PluginDefs { orClass = orCls, isSatTyCon = isSatTc }

--------------------------------------------------------------------------------

-- Constraint solving.


-- | Constraint solver entry point.
--
-- Each Wanted is handed to 'solveWanted'. Those it solves are reported as
-- solved; the others are left alone. No new constraints are emitted.
solver :: PluginDefs -> [ Ct ] -> [ Ct ] -> TcPluginM Solve TcPluginSolveResult
solver _    _      []      = pure $ TcPluginOk [] []
solver defs givens wanteds = do
  tcPluginTrace "IfSat solver {" ( ppr givens $$ ppr wanteds )
  attempts <- mapM ( solveWanted defs givens ) wanteds
  let
    solveds :: [ ( EvTerm, Ct ) ]
    solveds = catMaybes attempts
  tcPluginTrace "IfSat solver }" empty
  pure $ TcPluginOk solveds []

-- | Try to solve a single Wanted constraint.
--
-- Only @ct_l || ct_r@ constraints are handled; for any other constraint this
-- returns 'Nothing'. For a disjunction, a new solver run (sharing the
-- plugin's evidence bindings) first adds the Givens, then:
--
--   - tries to solve @ct_l@; if that succeeds /with no residual Wanteds/,
--     it returns evidence that takes the 'True' branch
--     (see 'dispatchTrueEvTerm');
--   - otherwise, it tries to solve @ct_r@; if that succeeds with no residual
--     Wanteds, it returns evidence that takes the 'False' branch
--     (see 'dispatchFalseEvTerm');
--   - otherwise, it returns 'Nothing', leaving the disjunction unsolved.
--
-- On success, the result pairs the evidence with the original Wanted.
--
-- Requiring an empty residual matters. Evidence for a constraint can be
-- bound even though residual Wanteds remain. Using that evidence would
-- refer to unsolved goals. The emptiness test is also the one
-- 'isSatRewriter' uses, so the branch picked here agrees with @IsSat@.
solveWanted :: PluginDefs -> [ Ct ] -> Ct -> TcPluginM Solve ( Maybe ( EvTerm, Ct ) )
solveWanted defs@( PluginDefs { orClass } ) givens wanted
  | ClassPred cls [ ct_l_ty, ct_r_ty ] <- classifyPredType ( ctPred wanted )
  , cls == orClass
  = do
    tcPluginTrace "IfSat solver: found (||) constraint"
      ( ppr ct_l_ty $$ ppr ct_r_ty $$ ppr wanted )
    ct_l_ev <- newWanted ( ctLoc wanted ) ct_l_ty
    ct_r_ev <- newWanted ( ctLoc wanted ) ct_r_ty
    let
      ct_l, ct_r :: Ct
      ct_l = mkNonCanonical ct_l_ev
      ct_r = mkNonCanonical ct_r_ev
      ct_l_ev_dest, ct_r_ev_dest :: TcEvDest
      ct_l_ev_dest = ctev_dest ct_l_ev
      ct_r_ev_dest = ctev_dest ct_r_ev
    evBindsVar <- askEvBinds
    -- Start a new Solver run, recording evidence in the plugin's EvBindsVar.
    unsafeLiftTcM $ runTcSWithEvBinds evBindsVar $ do
      -- Add back all the Givens.
      traceTcS "IfSat solver: adding Givens to the inert set" ( ppr givens )
      solveSimpleGivens givens
      -- Try to solve 'ct_l', using both Givens and top-level instances.
      ct_l_residual <- solveSimpleWanteds ( unitBag ct_l )
      -- Now look up whether GHC has managed to produce evidence for 'ct_l'.
      mb_ct_l_evTerm <- lookupEvTerm evBindsVar ct_l_ev_dest
      mb_wanted_evTerm <- case mb_ct_l_evTerm of
        Just ( EvExpr ct_l_evExpr )
          -- Evidence alone is not enough: it may have been bound while part
          -- of 'ct_l' remains unsolved (presumably e.g. after committing to
          -- an instance whose context fails). Require an empty residual.
          | isEmptyWC ct_l_residual
          -> do
            -- We've managed to solve 'ct_l': use the evidence and take the 'True' branch.
            traceTcS "IfSat solver: LHS constraint could be solved"
              ( vcat
                [ text "ct_l =" <+> ppr ct_l_ty
                , text "ev   =" <+> ppr ct_l_evExpr
                ]
              )
            wrapTcS $ ( Just <$> dispatchTrueEvTerm defs ct_l_ty ct_r_ty ct_l_evExpr )
        _ -> do
          -- We couldn't (fully) solve 'ct_l': this means we must solve 'ct_r',
          -- to provide evidence needed for the 'False' branch.
          traceTcS "IfSat solver: LHS constraint could not be solved"
            ( vcat
              [ text "ct_l =" <+> ppr ct_l_ty
              , text "residual =" <+> ppr ct_l_residual
              ]
            )
          -- Try to solve 'ct_r', using both Givens and top-level instances.
          ct_r_residual <- solveSimpleWanteds ( unitBag ct_r )
          mb_ct_r_evTerm <- lookupEvTerm evBindsVar ct_r_ev_dest
          case mb_ct_r_evTerm of
            Just ( EvExpr ct_r_evExpr )
              | isEmptyWC ct_r_residual
              -> do
                -- We've managed to solve 'ct_r': use the evidence and take the 'False' branch.
                traceTcS "IfSat solver: RHS constraint could be solved"
                  ( vcat
                    [ text "ct_r =" <+> ppr ct_r_ty
                    , text "ev   =" <+> ppr ct_r_evExpr
                    ]
                  )
                wrapTcS $ ( Just <$> dispatchFalseEvTerm defs ct_l_ty ct_r_ty ct_r_evExpr )
            _ -> do
              -- We could solve neither 'ct_l' nor 'ct_r'.
              -- This means we can't solve the disjunction constraint.
              traceTcS "IfSat solver: RHS constraint could not be solved"
                ( vcat
                  [ text "ct_r =" <+> ppr ct_r_ty
                  , text "residual =" <+> ppr ct_r_residual
                  ]
                )
              pure Nothing
      pure $ ( , wanted ) <$> mb_wanted_evTerm
  | otherwise
  = pure Nothing

-- | Look up whether a 'TcEvDest' has been filled with evidence.
--
--   - For a coercion hole, read the hole's reference. If it holds a
--     coercion, that coercion is the evidence.
--   - For an evidence variable, look for a binding of that variable among
--     the evidence bindings of the given 'EvBindsVar'. If there is one, its
--     right-hand side is the evidence.
--
-- Returns 'Nothing' when no evidence has been recorded yet.
lookupEvTerm :: EvBindsVar -> TcEvDest -> TcS ( Maybe EvTerm )
lookupEvTerm _ ( HoleDest ( CoercionHole { ch_ref = holeRef } ) ) = do
  filling <- readTcRef holeRef
  traceTcS "IfSat solver: coercion hole" ( ppr filling )
  pure ( evCoercion <$> filling )
lookupEvTerm bindsVar ( EvVarDest evVar ) = do
  bindsMap <- getTcEvBindsMap bindsVar
  let
    binding :: Maybe EvBind
    binding = lookupEvBind bindsMap evVar
  traceTcS "IfSat solver: evidence binding" ( ppr binding )
  pure ( eb_rhs <$> binding )

-- | Evidence term for @ct_l || ct_r@ when @ct_l@ is satisfied.
--
-- This is the @(||)@ class data constructor, applied to @ct_l@, @ct_r@ and
-- a dispatch function that calls its first continuation:
--
-- > dispatch =
-- >   \ @r
-- >     ( a :: ( IsSat ct_l ~ True, ct_l ) => r )
-- >     ( _ :: ( IsSat ct_l ~ False, IsSat ct_r ~ True, ct_r ) => r )
-- >   -> a ct_l_isSat_co ct_l_ev
--
-- The 'EvExpr' argument is the evidence for @ct_l@.
dispatchTrueEvTerm :: PluginDefs -> Type -> Type -> EvExpr -> TcM EvTerm
dispatchTrueEvTerm defs@( PluginDefs { orClass } ) ct_l_ty ct_r_ty ct_l_evExpr = do
  r_name <- newName ( mkTyVarOcc "r" )
  a_name <- newName ( mkVarOcc   "a" )
  let
    -- The result type @r@, of kind 'Type'.
    r :: CoreBndr
    r = mkTyVar r_name liftedTypeKind
    r_ty :: Type
    r_ty = mkTyVarTy r

    -- First continuation: @( IsSat ct_l ~ True, ct_l ) => r@. We call it.
    a :: CoreBndr
    a = mkLocalId a_name ManyTy
          ( mkInvisFunTys [ sat_eqTy defs ct_l_ty True, ct_l_ty ] r_ty )

    -- Second continuation: never called, hence a wildcard binder.
    b :: CoreBndr
    b = mkWildValBinder ManyTy
          ( mkInvisFunTys
              [ sat_eqTy defs ct_l_ty False, sat_eqTy defs ct_r_ty True, ct_r_ty ]
              r_ty
          )

    -- Call 'a' with a proof of @IsSat ct_l ~ True@ and the evidence for @ct_l@.
    body :: EvExpr
    body = mkCoreApps ( Var a ) [ sat_co_expr defs ct_l_ty True, ct_l_evExpr ]

    dispatchFn :: EvExpr
    dispatchFn = mkCoreLams [ r, a, b ] body

  pure $ EvExpr $
    mkCoreConApps ( classDataCon orClass ) [ Type ct_l_ty, Type ct_r_ty, dispatchFn ]
-- | Evidence term for @ct_l || ct_r@ when @ct_l@ isn't satisfied, but @ct_r@ is.
--
-- This is the @(||)@ class data constructor, applied to @ct_l@, @ct_r@ and
-- a dispatch function that calls its second continuation:
--
-- > dispatch =
-- >   \ @r
-- >     ( _ :: ( IsSat ct_l ~ True, ct_l ) => r )
-- >     ( b :: ( IsSat ct_l ~ False, IsSat ct_r ~ True, ct_r ) => r )
-- >   -> b ct_l_notSat_co ct_r_isSat_co ct_r_ev
--
-- The 'EvExpr' argument is the evidence for @ct_r@.
dispatchFalseEvTerm :: PluginDefs -> Type -> Type -> EvExpr -> TcM EvTerm
dispatchFalseEvTerm defs@( PluginDefs { orClass } ) ct_l_ty ct_r_ty ct_r_evExpr = do
  r_name <- newName ( mkTyVarOcc "r" )
  b_name <- newName ( mkVarOcc   "b" )
  let
    -- The result type @r@, of kind 'Type'.
    r :: CoreBndr
    r = mkTyVar r_name liftedTypeKind
    r_ty :: Type
    r_ty = mkTyVarTy r

    -- First continuation: never called, hence a wildcard binder.
    a :: CoreBndr
    a = mkWildValBinder ManyTy
          ( mkInvisFunTys [ sat_eqTy defs ct_l_ty True, ct_l_ty ] r_ty )

    -- Second continuation:
    -- @( IsSat ct_l ~ False, IsSat ct_r ~ True, ct_r ) => r@. We call it.
    b :: CoreBndr
    b = mkLocalId b_name ManyTy
          ( mkInvisFunTys
              [ sat_eqTy defs ct_l_ty False, sat_eqTy defs ct_r_ty True, ct_r_ty ]
              r_ty
          )

    -- Call 'b' with proofs of @IsSat ct_l ~ False@ and @IsSat ct_r ~ True@,
    -- and the evidence for @ct_r@.
    body :: EvExpr
    body = mkCoreApps ( Var b )
      [ sat_co_expr defs ct_l_ty False
      , sat_co_expr defs ct_r_ty True
      , ct_r_evExpr
      ]

    dispatchFn :: EvExpr
    dispatchFn = mkCoreLams [ r, a, b ] body

  pure $ EvExpr $
    mkCoreConApps ( classDataCon orClass ) [ Type ct_l_ty, Type ct_r_ty, dispatchFn ]

-- | @ sat_eqTy defs ct_ty b @ represents the type @ IsSat ct ~ b @.
--
-- This is the equality constraint (via 'eqTyCon' at kind 'Bool') between
-- @IsSat ct_ty@ and the promoted 'True' or 'False'.
sat_eqTy :: PluginDefs -> Type -> Bool -> Type
sat_eqTy ( PluginDefs { isSatTyCon } ) ct_ty booly =
  mkTyConApp eqTyCon [ boolTy, isSat_ct, promotedBool ]
  where
    isSat_ct, promotedBool :: Type
    isSat_ct = mkTyConApp isSatTyCon [ ct_ty ]
    promotedBool
      | booly     = tru
      | otherwise = fls

-- | @ sat_co_expr defs ct_ty b @ is an expression of type @ IsSat ct ~ b @.
--
-- It applies 'eqDataCon' to a plugin-asserted nominal coercion between
-- @IsSat ct_ty@ and the promoted 'True' or 'False'. The coercion is not
-- checked by GHC; it is justified by the plugin's own solving.
sat_co_expr :: PluginDefs -> Type -> Bool -> EvExpr
sat_co_expr ( PluginDefs { isSatTyCon } ) ct_ty booly =
  mkCoreConApps eqDataCon
    [ Type boolTy
    , Type lhs
    , Type rhs
    , Coercion co
    ]
  where
    lhs, rhs :: Type
    lhs = mkTyConApp isSatTyCon [ ct_ty ]
    rhs
      | booly     = tru
      | otherwise = fls
    co :: Coercion
    co = mkPluginUnivCo ( "IfSat :" <> show booly ) Nominal lhs rhs

-- | The promoted data constructors 'True' and 'False', as types.
tru :: Type
tru = mkTyConTy promotedTrueDataCon

fls :: Type
fls = mkTyConTy promotedFalseDataCon

--------------------------------------------------------------------------------


-- | The plugin's type-family rewriters. There is exactly one, for the
-- @IsSat@ type constructor, implemented by 'isSatRewriter'.
rewriter :: PluginDefs -> UniqFM TyCon TcPluginRewriter
rewriter defs = listToUFM [ ( isSatTyCon defs, isSatRewriter defs ) ]

-- | Rewrite @IsSat ct@ to the promoted 'True' or 'False'.
--
-- A fresh Wanted for @ct@ is solved in a separate solver run, after adding
-- the Givens. @ct@ counts as satisfied exactly when no residual Wanteds
-- remain. The evidence bindings of that run are discarded. The rewrite is
-- justified by a plugin reduction, and no new constraints are emitted.
--
-- Applications of @IsSat@ to anything other than exactly one argument are
-- not rewritten.
isSatRewriter :: PluginDefs -> [Ct] -> [Type] -> TcPluginM Rewrite TcPluginRewriteResult
isSatRewriter ( PluginDefs { isSatTyCon } ) givens [ ct_ty ] = do
  tcPluginTrace "IfSat rewriter {" ( ppr givens $$ ppr ct_ty )
  ct_loc <- rewriteEnvCtLoc <$> askRewriteEnv
  ct_ev  <- newWanted ct_loc ct_ty
  -- Start a new Solver run; only its verdict is used.
  ( redn, _ ) <- unsafeLiftTcM $ runTcS $ decideIsSat ( mkNonCanonical ct_ev )
  tcPluginTrace "IfSat rewriter }" ( ppr redn )
  pure $ TcPluginRewriteTo redn []
  where
    -- Solve 'ct' in the context of the Givens, and build the reduction of
    -- @IsSat ct_ty@ according to whether anything was left unsolved.
    decideIsSat ct = do
      -- Add back all the Givens.
      traceTcS "IfSat rewriter: adding Givens to the inert set" ( ppr givens )
      solveSimpleGivens givens
      -- Try to solve 'ct', using both Givens and top-level instances.
      residual_wc <- solveSimpleWanteds ( unitBag ct )
      -- When there are residual Wanteds, we couldn't solve the constraint.
      let
        is_sat :: Bool
        is_sat = isEmptyWC residual_wc
      pure $
        mkTyFamAppReduction ( "IsSat: " <> show is_sat ) Nominal isSatTyCon [ ct_ty ]
          ( if is_sat then tru else fls )
isSatRewriter _ _ _ = pure TcPluginNoRewrite

--------------------------------------------------------------------------------


#if !MIN_VERSION_ghc(9,2,0)
-- | Lift a 'TcM' computation into 'TcS'.
--
-- Only defined for GHC < 9.2. On those versions the import list above does
-- not take 'wrapTcS' from "GHC.Tc.Solver.Monad", so a local definition is
-- needed.
--
-- NOTE(review): coercing 'const' presumably relies on 'TcS' being a newtype
-- around a function @TcSEnv -> TcM a@ on those GHC versions. The lifted
-- computation then simply ignores the solver environment. Confirm this
-- against each supported GHC's source, as 'unsafeCoerce' gives no checking.
wrapTcS :: TcM a -> TcS a
wrapTcS = unsafeCoerce const
#endif