{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TupleSections #-}
module IfSat.Plugin
( plugin )
where
import Data.Maybe
( catMaybes )
#if !MIN_VERSION_ghc(9,2,0)
import Unsafe.Coerce
( unsafeCoerce )
#endif
import GHC.Plugins
( Plugin(..)
, defaultPlugin, purePlugin
)
import GHC.Data.Bag
( unitBag )
import GHC.Tc.Solver.Interact
( solveSimpleGivens, solveSimpleWanteds )
import GHC.Tc.Solver.Monad
( TcS
, getTcEvBindsMap, readTcRef, runTcS, runTcSWithEvBinds, traceTcS
#if MIN_VERSION_ghc(9,2,0)
, wrapTcS
#endif
)
import GHC.Tc.Types
( TcM )
import GHC.Tc.Types.Constraint
( isEmptyWC )
import GHC.Utils.Outputable
( (<+>), ($$), empty, text, vcat )
import GHC.TcPlugin.API
import GHC.TcPlugin.API.Internal
( unsafeLiftTcM )
-- | The GHC plugin exported by this module.
--
-- It starts from 'defaultPlugin', installs the IfSat type-checker plugin
-- (regardless of any command-line arguments passed to the plugin) and sets
-- the recompilation behaviour to 'purePlugin'.
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = installTcPlugin, pluginRecompile = purePlugin }
  where
    -- The plugin's command-line options are ignored.
    installTcPlugin _cmdLineOpts = Just ( mkTcPlugin ifSatTcPlugin )
-- | The type-checker plugin itself: initialisation looks up the definitions
-- the plugin needs ('initPlugin'), constraint solving is done by 'solver',
-- type-family rewriting by 'rewriter', and shutting down does nothing.
ifSatTcPlugin :: TcPlugin
ifSatTcPlugin =
  TcPlugin
    { tcPluginStop    = const ( pure () )
    , tcPluginRewrite = rewriter
    , tcPluginSolve   = solver
    , tcPluginInit    = initPlugin
    }
-- | Definitions looked up once by 'initPlugin' and then threaded through the
-- solver and the rewriter.
data PluginDefs =
  PluginDefs
    { -- | The class looked up under the name @||@ in @Data.Constraint.If@.
      orClass    :: !Class
      -- | The type constructor looked up under the name @IsSat@ in
      -- @Data.Constraint.If@.
    , isSatTyCon :: !TyCon
    }
-- | Look up an imported module by name, optionally restricting the search
-- to the given package name.
--
-- Calls 'error' when the lookup yields several candidates, and also when it
-- yields anything other than a single found module.
findModule :: MonadTcPlugin m => Maybe String -> String -> m Module
findModule mb_pkg modName =
  findImportedModule ( mkModuleName modName ) ( fsLit <$> mb_pkg ) >>= \ lookupOutcome ->
    case lookupOutcome of
      Found _ foundModule ->
        pure foundModule
      FoundMultiple _ ->
        abort "IfSat plugin: found multiple modules named "
      _ ->
        abort "IfSat plugin: could not find any module named "
  where
    -- Finish the message with the module name and a full stop, then bail out.
    abort prefix = error ( prefix <> modName <> "." )
-- | Plugin initialisation: locate the module @Data.Constraint.If@ and look up
-- the @||@ class and the @IsSat@ type constructor it defines.
initPlugin :: TcPluginM Init PluginDefs
initPlugin = do
  defsModule <- findModule Nothing "Data.Constraint.If"
  orClassName <- lookupOrig defsModule ( mkClsOcc "||" )
  cls <- tcLookupClass orClassName
  isSatName <- lookupOrig defsModule ( mkTcOcc "IsSat" )
  tc <- tcLookupTyCon isSatName
  pure PluginDefs { orClass = cls, isSatTyCon = tc }
-- | Constraint-solving entry point of the plugin.
--
-- With no Wanteds there is nothing to do. Otherwise every Wanted is passed to
-- 'solveWanted' together with the Givens, and the ones for which evidence was
-- produced are reported as solved. No new constraints are emitted.
solver :: PluginDefs -> [ Ct ] -> [ Ct ] -> TcPluginM Solve TcPluginSolveResult
solver _ _ [] = pure $ TcPluginOk [] []
solver defs givens wanteds = do
  tcPluginTrace "IfSat solver {" ( ppr givens $$ ppr wanteds )
  attempts <- traverse ( solveWanted defs givens ) wanteds
  tcPluginTrace "IfSat solver }" empty
  pure $ TcPluginOk ( catMaybes attempts ) []
-- | Attempt to solve a single Wanted constraint.
--
-- Only Wanteds whose predicate is the plugin's @(||)@ class applied to exactly
-- two arguments are handled; for anything else the result is 'Nothing'.
--
-- For a constraint @ct_l || ct_r@, a fresh Wanted is created for each side.
-- Then, inside a 'TcS' run that uses the plugin's evidence bindings variable,
-- the Givens are handed to the solver and the left-hand side is tried first.
-- If that yields evidence in the form of an 'EvExpr', the result is built with
-- 'dispatchTrueEvTerm'. Otherwise the right-hand side is tried, and on success
-- 'dispatchFalseEvTerm' is used. When neither side yields such evidence the
-- result is 'Nothing'.
solveWanted :: PluginDefs -> [ Ct ] -> Ct -> TcPluginM Solve ( Maybe ( EvTerm, Ct ) )
solveWanted defs@( PluginDefs { orClass } ) givens wanted =
  case classifyPredType ( ctPred wanted ) of
    ClassPred cls [ lhsTy, rhsTy ]
      | cls == orClass
      -> solveDisjunction lhsTy rhsTy
    _ -> pure Nothing
  where
    solveDisjunction :: Type -> Type -> TcPluginM Solve ( Maybe ( EvTerm, Ct ) )
    solveDisjunction lhsTy rhsTy = do
      tcPluginTrace "IfSat solver: found (||) constraint"
        ( ppr lhsTy $$ ppr rhsTy $$ ppr wanted )
      -- Fresh Wanteds for both sides, located where the original Wanted is.
      lhsEv <- newWanted ( ctLoc wanted ) lhsTy
      rhsEv <- newWanted ( ctLoc wanted ) rhsTy
      evBindsVar <- askEvBinds
      let
        -- Run the simple solver on one side, then fetch whatever evidence
        -- was recorded for that side's destination.
        attempt sideEv = do
          _ <- solveSimpleWanteds ( unitBag ( mkNonCanonical sideEv ) )
          lookupEvTerm evBindsVar ( ctev_dest sideEv )

        -- Fallback used only when the left-hand side yielded no evidence.
        tryRhs = do
          mb_rhsEvTerm <- attempt rhsEv
          case mb_rhsEvTerm of
            Just ( EvExpr rhsExpr ) -> do
              traceTcS "IfSat solver: RHS constraint could be solved"
                ( vcat
                    [ text "ct_r =" <+> ppr rhsTy
                    , text "ev =" <+> ppr rhsExpr
                    ]
                )
              wrapTcS ( Just <$> dispatchFalseEvTerm defs lhsTy rhsTy rhsExpr )
            _ -> do
              traceTcS "IfSat solver: RHS constraint could not be solved"
                ( text "ct_r =" <+> ppr rhsTy )
              pure Nothing

      unsafeLiftTcM $ runTcSWithEvBinds evBindsVar $ do
        traceTcS "IfSat solver: adding Givens to the inert set" ( ppr givens )
        solveSimpleGivens givens
        mb_lhsEvTerm <- attempt lhsEv
        mb_evTerm <- case mb_lhsEvTerm of
          Just ( EvExpr lhsExpr ) -> do
            traceTcS "IfSat solver: LHS constraint could be solved"
              ( vcat
                  [ text "ct_l =" <+> ppr lhsTy
                  , text "ev =" <+> ppr lhsExpr
                  ]
              )
            wrapTcS ( Just <$> dispatchTrueEvTerm defs lhsTy rhsTy lhsExpr )
          _ -> do
            traceTcS "IfSat solver: LHS constraint could not be solved"
              ( text "ct_l =" <+> ppr lhsTy )
            tryRhs
        -- Pair the evidence (if any) with the constraint it solves.
        pure $ fmap ( \ evTerm -> ( evTerm, wanted ) ) mb_evTerm
-- | Fetch the evidence currently recorded for an evidence destination, if any.
--
-- For a coercion hole, the hole's reference is read and its contents (when
-- filled) are wrapped with 'evCoercion'. For an evidence variable, the
-- variable is looked up in the bindings map of the given 'EvBindsVar' and the
-- right-hand side of its binding is returned.
lookupEvTerm :: EvBindsVar -> TcEvDest -> TcS ( Maybe EvTerm )
lookupEvTerm evBindsVar dest =
  case dest of
    HoleDest ( CoercionHole { ch_ref = holeRef } ) -> do
      holeContents <- readTcRef holeRef
      traceTcS "IfSat solver: coercion hole" ( ppr holeContents )
      pure ( evCoercion <$> holeContents )
    EvVarDest evVar -> do
      bindings <- getTcEvBindsMap evBindsVar
      let
        binding :: Maybe EvBind
        binding = lookupEvBind bindings evVar
      traceTcS "IfSat solver: evidence binding" ( ppr binding )
      pure ( eb_rhs <$> binding )
-- | Build evidence for @ct_l || ct_r@ out of evidence for @ct_l@.
--
-- The result is the data constructor of the @(||)@ class applied to both
-- constraint types and to a function of the shape
--
-- > \ @r a _ -> a <IsSat ct_l ~ True> <ct_l evidence>
--
-- that is, the first continuation is selected and the second is ignored.
dispatchTrueEvTerm :: PluginDefs -> Type -> Type -> EvExpr -> TcM EvTerm
dispatchTrueEvTerm defs@( PluginDefs { orClass } ) ct_l_ty ct_r_ty ct_l_evExpr = do
  resultName <- newName ( mkTyVarOcc "r" )
  chosenName <- newName ( mkVarOcc "a" )
  let
    resultTyVar :: CoreBndr
    resultTyVar = mkTyVar resultName liftedTypeKind

    resultTy :: Type
    resultTy = mkTyVarTy resultTyVar

    -- Types of the two continuations taken by the dispatch function.
    lhsContTy, rhsContTy :: Type
    lhsContTy =
      mkInvisFunTysMany
        [ sat_eqTy defs ct_l_ty True, ct_l_ty ]
        resultTy
    rhsContTy =
      mkInvisFunTysMany
        [ sat_eqTy defs ct_l_ty False, sat_eqTy defs ct_r_ty True, ct_r_ty ]
        resultTy

    -- The left continuation is the one applied; the right one is a wildcard.
    lhsCont, rhsCont :: CoreBndr
    lhsCont = mkLocalId chosenName Many lhsContTy
    rhsCont = mkWildValBinder Many rhsContTy

    dispatchBody :: EvExpr
    dispatchBody =
      mkCoreApps ( Var lhsCont )
        [ sat_co_expr defs ct_l_ty True
        , ct_l_evExpr
        ]
  pure . EvExpr $
    mkCoreConApps ( classDataCon orClass )
      [ Type ct_l_ty
      , Type ct_r_ty
      , mkCoreLams [ resultTyVar, lhsCont, rhsCont ] dispatchBody
      ]
-- | Build evidence for @ct_l || ct_r@ out of evidence for @ct_r@.
--
-- The result is the data constructor of the @(||)@ class applied to both
-- constraint types and to a function of the shape
--
-- > \ @r _ b -> b <IsSat ct_l ~ False> <IsSat ct_r ~ True> <ct_r evidence>
--
-- that is, the first continuation is ignored and the second is selected.
dispatchFalseEvTerm :: PluginDefs -> Type -> Type -> EvExpr -> TcM EvTerm
dispatchFalseEvTerm defs@( PluginDefs { orClass } ) ct_l_ty ct_r_ty ct_r_evExpr = do
  resultName <- newName ( mkTyVarOcc "r" )
  chosenName <- newName ( mkVarOcc "b" )
  let
    resultTyVar :: CoreBndr
    resultTyVar = mkTyVar resultName liftedTypeKind

    resultTy :: Type
    resultTy = mkTyVarTy resultTyVar

    -- Types of the two continuations taken by the dispatch function.
    lhsContTy, rhsContTy :: Type
    lhsContTy =
      mkInvisFunTysMany
        [ sat_eqTy defs ct_l_ty True, ct_l_ty ]
        resultTy
    rhsContTy =
      mkInvisFunTysMany
        [ sat_eqTy defs ct_l_ty False, sat_eqTy defs ct_r_ty True, ct_r_ty ]
        resultTy

    -- The right continuation is the one applied; the left one is a wildcard.
    lhsCont, rhsCont :: CoreBndr
    lhsCont = mkWildValBinder Many lhsContTy
    rhsCont = mkLocalId chosenName Many rhsContTy

    dispatchBody :: EvExpr
    dispatchBody =
      mkCoreApps ( Var rhsCont )
        [ sat_co_expr defs ct_l_ty False
        , sat_co_expr defs ct_r_ty True
        , ct_r_evExpr
        ]
  pure . EvExpr $
    mkCoreConApps ( classDataCon orClass )
      [ Type ct_l_ty
      , Type ct_r_ty
      , mkCoreLams [ resultTyVar, lhsCont, rhsCont ] dispatchBody
      ]
-- | The equality type @IsSat ct_ty ~ b@ at kind @Bool@, where @b@ is the
-- promoted @True@ or @False@ selected by the 'Bool' argument.
sat_eqTy :: PluginDefs -> Type -> Bool -> Type
sat_eqTy ( PluginDefs { isSatTyCon = satTyCon } ) ct_ty booly =
  mkTyConApp eqTyCon [ boolTy, isSatApp, expected ]
  where
    isSatApp :: Type
    isSatApp = mkTyConApp satTyCon [ ct_ty ]

    expected :: Type
    expected
      | booly     = tru
      | otherwise = fls
-- | An evidence expression for the equality @IsSat ct_ty ~ b@, where @b@ is
-- the promoted @True@ or @False@ selected by the 'Bool' argument.
--
-- The expression applies 'eqDataCon' to the kind @Bool@, the two sides of the
-- equality, and a nominal plugin-provided universal coercion between them.
sat_co_expr :: PluginDefs -> Type -> Bool -> EvExpr
sat_co_expr ( PluginDefs { isSatTyCon = satTyCon } ) ct_ty booly =
  mkCoreConApps eqDataCon
    [ Type boolTy
    , Type isSatApp
    , Type expected
    , Coercion pluginCo
    ]
  where
    isSatApp :: Type
    isSatApp = mkTyConApp satTyCon [ ct_ty ]

    expected :: Type
    expected
      | booly     = tru
      | otherwise = fls

    pluginCo = mkPluginUnivCo ( "IfSat :" <> show booly ) Nominal isSatApp expected
-- | The promoted data constructor @True@, as a type.
tru :: Type
tru = mkTyConTy promotedTrueDataCon

-- | The promoted data constructor @False@, as a type.
fls :: Type
fls = mkTyConTy promotedFalseDataCon
-- | Type-family rewriting table of the plugin: a single entry, mapping the
-- @IsSat@ type constructor to 'isSatRewriter'.
rewriter :: PluginDefs -> UniqFM TyCon TcPluginRewriter
rewriter defs = listToUFM [ ( rewrittenTyCon, handler ) ]
  where
    rewrittenTyCon = isSatTyCon defs
    handler = isSatRewriter defs
-- | Rewrite an application @IsSat ct@.
--
-- When the family is applied to exactly one argument, a fresh Wanted for that
-- argument is created and handed to the simple solver in a fresh 'TcS' run,
-- after the Givens have been added. The application is rewritten to the
-- promoted @True@ if no residual constraints remain, and to the promoted
-- @False@ otherwise. No new Wanteds are emitted.
--
-- With any other number of arguments no rewriting happens.
isSatRewriter :: PluginDefs -> [Ct] -> [Type] -> TcPluginM Rewrite TcPluginRewriteResult
isSatRewriter ( PluginDefs { isSatTyCon } ) givens args =
  case args of
    [ ct_ty ] -> do
      tcPluginTrace "IfSat rewriter {" ( ppr givens $$ ppr ct_ty )
      env <- askRewriteEnv
      wantedEv <- newWanted ( rewriteEnvCtLoc env ) ct_ty
      -- Only the computed reduction is kept from the solver run.
      ( redn, _ ) <- unsafeLiftTcM . runTcS $ do
        traceTcS "IfSat rewriter: adding Givens to the inert set" ( ppr givens )
        solveSimpleGivens givens
        leftover <- solveSimpleWanteds ( unitBag ( mkNonCanonical wantedEv ) )
        let
          is_sat :: Bool
          is_sat = isEmptyWC leftover
          verdict :: Type
          verdict = if is_sat then tru else fls
        pure $ mkTyFamAppReduction ( "IsSat: " <> show is_sat ) Nominal isSatTyCon [ ct_ty ] verdict
      tcPluginTrace "IfSat rewriter }" ( ppr redn )
      pure $ TcPluginRewriteTo redn []
    _ -> pure TcPluginNoRewrite
#if !MIN_VERSION_ghc(9,2,0)
-- | Lift a 'TcM' computation into 'TcS'.
--
-- This local definition is only compiled for GHC versions older than 9.2;
-- on newer versions the module's import list takes @wrapTcS@ from
-- "GHC.Tc.Solver.Monad" instead.
--
-- NOTE(review): the implementation coerces 'const' with 'unsafeCoerce'.
-- This presumably relies on @TcS a@ being a newtype around a function of the
-- form @env -> TcM a@, so that @const :: TcM a -> env -> TcM a@ has the same
-- runtime representation as @TcM a -> TcS a@ -- confirm against the 'TcS'
-- definition of every GHC version this branch is compiled with.
wrapTcS :: TcM a -> TcS a
wrapTcS = unsafeCoerce const
#endif