{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module GHC.TcPlugin.API.Internal
(
MonadTcPlugin(..), MonadTcPluginWork
, unsafeLiftThroughTcM
, TcPlugin(..), TcPluginStage(..)
, TcPluginSolver
, TcPluginM(..)
, TcPluginErrorMessage(..)
, TcPluginRewriter
, askRewriteEnv
, askDeriveds
, askEvBinds
, mkTcPlugin
, mkTcPluginErrorTy
)
where
import Data.Coerce
( Coercible )
import Data.Kind
( Constraint, Type )
import GHC.TypeLits
( TypeError, ErrorMessage(..) )
import Control.Monad.Trans.Reader
( ReaderT(..) )
import qualified GHC.Builtin.Names
as GHC.TypeLits
( errorMessageTypeErrorFamName
, typeErrorTextDataConName
, typeErrorAppendDataConName
, typeErrorVAppendDataConName
, typeErrorShowTypeDataConName
)
import qualified GHC.Builtin.Types
as GHC
( constraintKind )
import qualified GHC.Core.DataCon
as GHC
( promoteDataCon )
import qualified GHC.Core.TyCon
as GHC
( TyCon )
import qualified GHC.Core.TyCo.Rep
as GHC
( PredType, Type(..), TyLit(..) )
import qualified GHC.Core.Type
as GHC
( mkTyConApp, tcTypeKind )
import qualified GHC.Data.FastString
as GHC
( fsLit )
import qualified GHC.Tc.Plugin
as GHC
( tcLookupDataCon, tcLookupTyCon )
import qualified GHC.Tc.Types
as GHC
( TcM, TcPlugin(..), TcPluginM
, TcPluginSolver
#ifdef HAS_REWRITING
, TcPluginRewriter
#else
, getEvBindsTcPluginM
#endif
, runTcPluginM, unsafeTcPluginTcM
)
#ifdef HAS_REWRITING
import GHC.Tc.Types
( TcPluginSolveResult
, TcPluginRewriteResult
, RewriteEnv
)
#endif
import qualified GHC.Tc.Types.Constraint
as GHC
( Ct )
import qualified GHC.Tc.Types.Evidence
as GHC
( EvBindsVar )
import qualified GHC.Types.Unique.FM
as GHC
( UniqFM )
#ifndef HAS_REWRITING
import GHC.TcPlugin.API.Internal.Shim
( TcPluginSolveResult, TcPluginRewriteResult(..)
, RewriteEnv
, shimRewriter
)
#endif
-- | The stages of a type-checking plugin. Used to index 'TcPluginM', so that
-- each stage runs in its own monad.
data TcPluginStage
  = Init    -- ^ Initialisation ('tcPluginInit').
  | Solve   -- ^ Constraint solving ('tcPluginSolve').
  | Rewrite -- ^ Rewriting ('tcPluginRewrite').
  | Stop    -- ^ Shutdown ('tcPluginStop').
-- | The type of a user-supplied constraint solver, run in the 'Solve' stage.
type TcPluginSolver
  =  [GHC.Ct] -- ^ Givens.
  -> [GHC.Ct] -- ^ Wanteds.
  -> TcPluginM Solve TcPluginSolveResult
-- | The type of a user-supplied rewriter, run in the 'Rewrite' stage.
type TcPluginRewriter
  =  [GHC.Ct]   -- ^ Givens.
  -> [GHC.Type] -- ^ Types to rewrite (presumably the type family arguments; confirm).
  -> TcPluginM Rewrite TcPluginRewriteResult
-- | A type-checking plugin written against the staged 'TcPluginM' monads.
--
-- The existentially quantified @s@ is the state produced by 'tcPluginInit'
-- and passed to every later stage. Use 'mkTcPlugin' to turn this into the
-- 'GHC.TcPlugin' that GHC expects.
data TcPlugin = forall s. TcPlugin
  { tcPluginInit :: TcPluginM Init s
    -- ^ Initialise the plugin, producing its state.
  , tcPluginSolve :: s -> TcPluginSolver
    -- ^ Solve constraints.
  , tcPluginRewrite
      :: s -> GHC.UniqFM
#if MIN_VERSION_ghc(9,0,0)
           GHC.TyCon
#endif
           TcPluginRewriter
    -- ^ Rewriters, keyed by 'GHC.TyCon' (presumably the type families they rewrite).
  , tcPluginStop :: s -> TcPluginM Stop ()
    -- ^ Run when the plugin is stopped.
  }
-- | The monad a plugin runs in during the given 'TcPluginStage'.
--
-- Every instance wraps the GHC 'GHC.TcPluginM' monad, together with whatever
-- read-only context that stage has access to.
type TcPluginM :: TcPluginStage -> ( Type -> Type )
data family TcPluginM s

-- Init-stage monad: a plain wrapper around 'GHC.TcPluginM'.
newtype instance TcPluginM Init a =
  TcPluginInitM { tcPluginInitM :: GHC.TcPluginM a }
  deriving newtype ( Functor, Applicative, Monad )

#ifdef HAS_DERIVEDS
-- Solve-stage monad: reads the builtin definitions, the evidence binding
-- variable and the Derived constraints.
newtype instance TcPluginM Solve a =
  TcPluginSolveM { tcPluginSolveM :: BuiltinDefs -> GHC.EvBindsVar -> [GHC.Ct] -> GHC.TcPluginM a }
  deriving ( Functor, Applicative, Monad )
    via ( ReaderT BuiltinDefs ( ReaderT GHC.EvBindsVar ( ReaderT [GHC.Ct] GHC.TcPluginM ) ) )
#else
-- Solve-stage monad: reads the builtin definitions and the evidence binding
-- variable.
newtype instance TcPluginM Solve a =
  TcPluginSolveM { tcPluginSolveM :: BuiltinDefs -> GHC.EvBindsVar -> GHC.TcPluginM a }
  deriving ( Functor, Applicative, Monad )
    via ( ReaderT BuiltinDefs ( ReaderT GHC.EvBindsVar GHC.TcPluginM ) )
#endif

-- Rewrite-stage monad: reads the builtin definitions and the 'RewriteEnv'.
newtype instance TcPluginM Rewrite a =
  TcPluginRewriteM { tcPluginRewriteM :: BuiltinDefs -> RewriteEnv -> GHC.TcPluginM a }
  deriving ( Functor, Applicative, Monad )
    via ( ReaderT BuiltinDefs ( ReaderT RewriteEnv GHC.TcPluginM ) )

-- Stop-stage monad: a plain wrapper around 'GHC.TcPluginM'.
newtype instance TcPluginM Stop a =
  TcPluginStopM { tcPluginStopM :: GHC.TcPluginM a }
  deriving newtype ( Functor, Applicative, Monad )
-- | Obtain the evidence binding variable available in the 'Solve' stage.
askEvBinds :: TcPluginM Solve GHC.EvBindsVar
askEvBinds =
  TcPluginSolveM
#ifdef HAS_DERIVEDS
    ( \ _builtins evBindsVar _derivedCts -> pure evBindsVar )
#else
    ( \ _builtins evBindsVar -> pure evBindsVar )
#endif
-- | Obtain the Derived constraints handed to the solver.
--
-- On GHC versions without Derived constraints this always returns the empty
-- list.
askDeriveds :: TcPluginM Solve [GHC.Ct]
#ifdef HAS_DERIVEDS
askDeriveds = TcPluginSolveM ( \ _builtins _evBindsVar derivedCts -> pure derivedCts )
#else
askDeriveds = pure []
#endif
-- | Obtain the 'RewriteEnv' available in the 'Rewrite' stage.
askRewriteEnv :: TcPluginM Rewrite RewriteEnv
askRewriteEnv = TcPluginRewriteM \ _builtins env -> pure env
-- | Monads that can run 'GHC.TcPluginM' actions, and unsafely run
-- 'GHC.TcM' actions. Every 'TcPluginM' stage in this module is an instance.
--
-- The quantified 'Coercible' superclass says that @m x@ may be coerced to
-- @m y@ whenever @x@ may be coerced to @y@.
type MonadTcPlugin :: ( Type -> Type ) -> Constraint
class ( Monad m
      , ( forall x y. Coercible x y => Coercible ( m x ) ( m y ) )
      )
   => MonadTcPlugin m
  where
    {-# MINIMAL liftTcPluginM, unsafeWithRunInTcM #-}

    -- | Lift a 'GHC.TcPluginM' action.
    liftTcPluginM :: GHC.TcPluginM a -> m a

    -- | Lift a 'GHC.TcM' action. The default goes through
    -- 'GHC.unsafeTcPluginTcM'.
    unsafeLiftTcM :: GHC.TcM a -> m a
    unsafeLiftTcM tcAction = liftTcPluginM ( GHC.unsafeTcPluginTcM tcAction )

    -- | Run a 'GHC.TcM' action that is handed a function for running
    -- actions of @m@ inside 'GHC.TcM'.
    unsafeWithRunInTcM :: ( ( forall a. m a -> GHC.TcM a ) -> GHC.TcM b ) -> m b
-- | 'Init'-stage actions are plain 'GHC.TcPluginM' actions.
instance MonadTcPlugin ( TcPluginM Init ) where
  liftTcPluginM = TcPluginInitM
  unsafeWithRunInTcM runInTcM = unsafeLiftTcM ( runInTcM runInit )
    where
      -- Run an Init-stage action directly in 'GHC.TcM'.
      runInit :: TcPluginM Init x -> GHC.TcM x
#ifdef HAS_REWRITING
      runInit initAction = GHC.runTcPluginM ( tcPluginInitM initAction )
#else
      -- There is no evidence binding variable to supply here, so an error
      -- thunk is passed; it is only forced if something tries to access it.
      runInit initAction =
        GHC.runTcPluginM ( tcPluginInitM initAction )
          ( error "tcPluginInit: cannot access EvBindsVar" )
#endif
-- | 'Solve'-stage actions read the builtin definitions, the evidence binding
-- variable and, when available, the Derived constraints.
instance MonadTcPlugin ( TcPluginM Solve ) where
  -- Lifting ignores the whole stage environment.
  liftTcPluginM = TcPluginSolveM
#ifdef HAS_DERIVEDS
    . ( \ ma _defs _evBinds _deriveds -> ma )
#else
    . ( \ ma _defs _evBinds -> ma )
#endif
  unsafeWithRunInTcM runInTcM
    = TcPluginSolveM
      \ builtinDefs
        evBinds
#ifdef HAS_DERIVEDS
        deriveds
#endif
        ->
        -- The runner handed to the caller feeds the captured environment to
        -- any Solve-stage action and runs the result in 'GHC.TcM'.
        GHC.unsafeTcPluginTcM $ runInTcM
#ifdef HAS_REWRITING
          ( GHC.runTcPluginM
#ifdef HAS_DERIVEDS
          . ( \ f -> f builtinDefs evBinds deriveds )
#else
          . ( \ f -> f builtinDefs evBinds )
#endif
          . tcPluginSolveM )
#else
          -- Here 'GHC.runTcPluginM' takes the evidence binding variable as an
          -- explicit argument.
          -- NOTE(review): this branch always uses the Derived constraints, so it
          -- relies on them being available whenever the rewriting API is not.
          ( ( `GHC.runTcPluginM` evBinds )
          . ( \ f -> f builtinDefs evBinds deriveds )
          . tcPluginSolveM
          )
#endif
-- | 'Rewrite'-stage actions read the builtin definitions and the 'RewriteEnv'.
instance MonadTcPlugin ( TcPluginM Rewrite ) where
  -- Lifting ignores the builtin definitions and the rewrite environment.
  liftTcPluginM = TcPluginRewriteM . ( \ ma _ _ -> ma )
  unsafeWithRunInTcM runInTcM
    = TcPluginRewriteM \ builtinDefs rewriteEnv ->
      -- The runner handed to the caller feeds the captured environment to
      -- any Rewrite-stage action and runs the result in 'GHC.TcM'.
      GHC.unsafeTcPluginTcM $ runInTcM
#ifdef HAS_REWRITING
        ( GHC.runTcPluginM
#else
        -- No evidence binding variable is available here; the error thunk is
        -- only forced if something tries to access it.
        ( ( `GHC.runTcPluginM` ( error "tcPluginRewrite: cannot access EvBindsVar" ) )
#endif
        . ( \ f -> f builtinDefs rewriteEnv )
        . tcPluginRewriteM )
-- | 'Stop'-stage actions are plain 'GHC.TcPluginM' actions.
instance MonadTcPlugin ( TcPluginM Stop ) where
  liftTcPluginM = TcPluginStopM
  unsafeWithRunInTcM runInTcM = unsafeLiftTcM ( runInTcM runStop )
    where
      -- Run a Stop-stage action directly in 'GHC.TcM'.
      runStop :: TcPluginM Stop x -> GHC.TcM x
#ifdef HAS_REWRITING
      runStop stopAction = GHC.runTcPluginM ( tcPluginStopM stopAction )
#else
      -- There is no evidence binding variable to supply here, so an error
      -- thunk is passed; it is only forced if something tries to access it.
      runStop stopAction =
        GHC.runTcPluginM ( tcPluginStopM stopAction )
          ( error "tcPluginStop: cannot access EvBindsVar" )
#endif
-- | Apply a transformation of 'GHC.TcM' computations to an action of any
-- 'MonadTcPlugin' monad, by running that action inside 'GHC.TcM'.
unsafeLiftThroughTcM :: MonadTcPlugin m => ( GHC.TcM a -> GHC.TcM b ) -> m a -> m b
unsafeLiftThroughTcM transform action =
  unsafeWithRunInTcM ( \ runInTcM -> transform ( runInTcM action ) )
-- | Convert a 'TcPlugin' into the 'GHC.TcPlugin' that GHC expects.
--
-- Initialisation first looks up the builtin definitions ('initBuiltinDefs')
-- and then runs the user's 'tcPluginInit'. Both results are packaged into a
-- 'TcPluginDefs' value, which is passed to every later stage. When GHC has no
-- rewriting-plugin support, the user's rewriters are run from inside the
-- solver via 'shimRewriter'.
mkTcPlugin
  :: TcPlugin
  -> GHC.TcPlugin
mkTcPlugin ( TcPlugin
               { tcPluginInit = tcPluginInit :: TcPluginM Init userDefs
               , tcPluginSolve
               , tcPluginRewrite
               , tcPluginStop
               }
           ) =
  GHC.TcPlugin
    { GHC.tcPluginInit = adaptUserInit tcPluginInit
#ifdef HAS_REWRITING
    , GHC.tcPluginSolve = adaptUserSolve tcPluginSolve
    , GHC.tcPluginRewrite = adaptUserRewrite tcPluginRewrite
#else
    , GHC.tcPluginSolve = adaptUserSolveAndRewrite
        tcPluginSolve tcPluginRewrite
#endif
    , GHC.tcPluginStop = adaptUserStop tcPluginStop
    }
  where
    -- Look up the builtins, then run the user's initialisation.
    adaptUserInit :: TcPluginM Init userDefs -> GHC.TcPluginM ( TcPluginDefs userDefs )
    adaptUserInit userInit = do
      tcPluginBuiltinDefs <- initBuiltinDefs
      tcPluginUserDefs    <- tcPluginInitM userInit
      pure ( TcPluginDefs { tcPluginBuiltinDefs, tcPluginUserDefs } )
#ifdef HAS_REWRITING
    -- Feed the builtins and the evidence binding variable to the user's solver.
    -- The user-facing 'TcPluginSolver' always takes just Givens and Wanteds.
    -- When GHC also passes Derived constraints, they are stored in the Solve
    -- environment instead (see 'askDeriveds').
    adaptUserSolve :: ( userDefs -> TcPluginSolver )
                   -> TcPluginDefs userDefs
                   -> GHC.EvBindsVar
                   -> GHC.TcPluginSolver
    adaptUserSolve userSolve ( TcPluginDefs { tcPluginUserDefs, tcPluginBuiltinDefs } )
      evBindsVar
#ifdef HAS_DERIVEDS
      = \ givens deriveds wanteds ->
          tcPluginSolveM ( userSolve tcPluginUserDefs givens wanteds )
            tcPluginBuiltinDefs evBindsVar deriveds
#else
      -- Without Derived constraints, 'GHC.TcPluginSolver' receives only the
      -- Givens and the Wanteds, so the lambda takes exactly two lists.
      = \ givens wanteds ->
          tcPluginSolveM ( userSolve tcPluginUserDefs givens wanteds )
            tcPluginBuiltinDefs evBindsVar
#endif
    -- Feed the builtins and the rewrite environment to each user rewriter.
    adaptUserRewrite :: ( userDefs -> GHC.UniqFM GHC.TyCon TcPluginRewriter )
                     -> TcPluginDefs userDefs -> GHC.UniqFM GHC.TyCon GHC.TcPluginRewriter
    adaptUserRewrite userRewrite ( TcPluginDefs { tcPluginUserDefs, tcPluginBuiltinDefs } )
      = fmap
          ( \ userRewriter rewriteEnv givens tys ->
              tcPluginRewriteM ( userRewriter givens tys ) tcPluginBuiltinDefs rewriteEnv
          )
          ( userRewrite tcPluginUserDefs )
#else
    -- Without GHC support for rewriting plugins, run both the user's
    -- rewriters (through 'shimRewriter') and the user's solver from the
    -- solver callback.
    adaptUserSolveAndRewrite
      :: ( userDefs -> TcPluginSolver )
      -> ( userDefs -> GHC.UniqFM
#if MIN_VERSION_ghc(9,0,0)
                         GHC.TyCon
#endif
                         TcPluginRewriter
         )
      -> TcPluginDefs userDefs
      -> GHC.TcPluginSolver
    adaptUserSolveAndRewrite userSolve userRewrite ( TcPluginDefs { tcPluginUserDefs, tcPluginBuiltinDefs } )
      = \ givens deriveds wanteds -> do
        evBindsVar <- GHC.getEvBindsTcPluginM
        shimRewriter
          givens deriveds wanteds
          ( fmap
              ( \ userRewriter rewriteEnv gs tys ->
                  tcPluginRewriteM ( userRewriter gs tys )
                    tcPluginBuiltinDefs rewriteEnv
              )
              ( userRewrite tcPluginUserDefs )
          )
          ( \ gs ds ws ->
              tcPluginSolveM ( userSolve tcPluginUserDefs gs ws )
                tcPluginBuiltinDefs evBindsVar ds
          )
#endif
    -- Run the user's shutdown action; the builtins are not needed here.
    adaptUserStop :: ( userDefs -> TcPluginM Stop () ) -> TcPluginDefs userDefs -> GHC.TcPluginM ()
    adaptUserStop userStop ( TcPluginDefs { tcPluginUserDefs } ) =
      tcPluginStopM $ userStop tcPluginUserDefs
-- | Monads in which new work can be emitted.
--
-- In this module, only the 'Solve' and 'Rewrite' stages have working
-- instances. The 'Init' and 'Stop' instances carry a 'TypeError' context, so
-- any use of them is rejected at compile time.
type MonadTcPluginWork :: ( Type -> Type ) -> Constraint
class MonadTcPlugin m => MonadTcPluginWork m where
  {-# MINIMAL #-}
  -- | Access the builtin definitions looked up during initialisation.
  askBuiltins :: m BuiltinDefs
  -- NOTE(review): this default is a runtime error. Because of the empty
  -- MINIMAL pragma, GHC will not warn about an instance that omits this method.
  askBuiltins = error "askBuiltins: no default implementation"
-- | In the 'Solve' stage, the builtin definitions are read from the stage
-- environment.
instance MonadTcPluginWork ( TcPluginM Solve ) where
#ifdef HAS_DERIVEDS
  askBuiltins = TcPluginSolveM ( \ builtins _evBindsVar _derivedCts -> pure builtins )
#else
  askBuiltins = TcPluginSolveM ( \ builtins _evBindsVar -> pure builtins )
#endif
-- | In the 'Rewrite' stage, the builtin definitions are read from the stage
-- environment; the 'RewriteEnv' is ignored.
instance MonadTcPluginWork ( TcPluginM Rewrite ) where
  askBuiltins = TcPluginRewriteM ( \ builtins _rewriteEnv -> pure builtins )
-- | No work can be emitted in the 'Init' stage. Any use of this instance is
-- reported at compile time through its 'TypeError' context.
instance TypeError ( 'Text "Cannot emit new work in 'tcPluginInit'." )
  => MonadTcPluginWork ( TcPluginM Init ) where
  -- Not expected to run, since the instance context is a type error.
  askBuiltins = error "Cannot emit new work in 'tcPluginInit'."
-- | No work can be emitted in the 'Stop' stage. Any use of this instance is
-- reported at compile time through its 'TypeError' context.
instance TypeError ( 'Text "Cannot emit new work in 'tcPluginStop'." )
  => MonadTcPluginWork ( TcPluginM Stop ) where
  -- Not expected to run, since the instance context is a type error.
  askBuiltins = error "Cannot emit new work in 'tcPluginStop'."
-- | An error message for plugins, mirroring the type-level @ErrorMessage@ of
-- "GHC.TypeLits". Turn it into a constraint with 'mkTcPluginErrorTy'.
data TcPluginErrorMessage
  = Txt !String
    -- ^ Show the text as is (the @Text@ constructor).
  | PrintType !GHC.Type
    -- ^ Pretty-print the given type (the @ShowType@ constructor).
  | (:|:) !TcPluginErrorMessage !TcPluginErrorMessage
    -- ^ Put two messages side by side (the @:<>:@ constructor).
  | (:-:) !TcPluginErrorMessage !TcPluginErrorMessage
    -- ^ Stack two messages vertically (the @:$$:@ constructor).
-- NOTE(review): these fixities are the reverse of the analogous operators in
-- GHC.TypeLits, where @:<>:@ is infixl 6 and @:$$:@ is infixl 5. As a result,
-- @a :|: b :-: c@ parses as @a :|: (b :-: c)@. Confirm this is intended
-- before relying on unparenthesised mixes of the two operators.
infixl 5 :|:
infixl 6 :-:
-- | Build a constraint whose predicate is the @TypeError@ type family, applied
-- at kind 'GHC.constraintKind' to the type-level rendering of the message.
mkTcPluginErrorTy :: MonadTcPluginWork m => TcPluginErrorMessage -> m GHC.PredType
mkTcPluginErrorTy msg = do
  builtins <- askBuiltins
  -- Matching on the constructor forces the builtins before the type is built.
  case builtins of
    BuiltinDefs { typeErrorTyCon = typeErrorFam } ->
      pure
        ( GHC.mkTyConApp typeErrorFam
            [ GHC.constraintKind, interpretErrorMessage builtins msg ]
        )
-- | Type constructors needed to build custom type errors. They are looked up
-- once, by 'initBuiltinDefs', during plugin initialisation.
data BuiltinDefs =
  BuiltinDefs
    { typeErrorTyCon :: !GHC.TyCon
      -- ^ The @TypeError@ type family.
    , textTyCon      :: !GHC.TyCon
      -- ^ The promoted @Text@ constructor.
    , showTypeTyCon  :: !GHC.TyCon
      -- ^ The promoted @ShowType@ constructor.
    , concatTyCon    :: !GHC.TyCon
      -- ^ The promoted @:<>:@ constructor (horizontal composition).
    , vcatTyCon      :: !GHC.TyCon
      -- ^ The promoted @:$$:@ constructor (vertical composition).
    }
-- | The state that a plugin built by 'mkTcPlugin' hands to GHC. It holds the
-- builtin definitions together with the user's own state.
data TcPluginDefs s
  = TcPluginDefs
    { tcPluginBuiltinDefs :: !BuiltinDefs
      -- ^ Builtins looked up by 'initBuiltinDefs'.
    , tcPluginUserDefs    :: !s
      -- ^ State returned by the user's 'tcPluginInit'.
    }
-- | Look up the @TypeError@ type family and promote the @ErrorMessage@ data
-- constructors used to build custom type errors.
initBuiltinDefs :: GHC.TcPluginM BuiltinDefs
initBuiltinDefs =
  BuiltinDefs
    <$> GHC.tcLookupTyCon GHC.TypeLits.errorMessageTypeErrorFamName
    <*> promotedCon GHC.TypeLits.typeErrorTextDataConName
    <*> promotedCon GHC.TypeLits.typeErrorShowTypeDataConName
    <*> promotedCon GHC.TypeLits.typeErrorAppendDataConName
    <*> promotedCon GHC.TypeLits.typeErrorVAppendDataConName
  where
    -- Look up a data constructor by name and promote it to the type level.
    promotedCon dataConName = GHC.promoteDataCon <$> GHC.tcLookupDataCon dataConName
-- | Translate a 'TcPluginErrorMessage' into the corresponding type-level
-- @ErrorMessage@, built from the looked-up builtin constructors.
interpretErrorMessage :: BuiltinDefs -> TcPluginErrorMessage -> GHC.PredType
interpretErrorMessage ( BuiltinDefs { .. } ) = render
  where
    render :: TcPluginErrorMessage -> GHC.PredType
    render message = case message of
      Txt str ->
        -- A type-level string literal wrapped in the promoted Text constructor.
        GHC.mkTyConApp textTyCon [ GHC.LitTy ( GHC.StrTyLit ( GHC.fsLit str ) ) ]
      PrintType ty ->
        -- The promoted ShowType constructor takes the kind of the type first.
        GHC.mkTyConApp showTypeTyCon [ GHC.tcTypeKind ty, ty ]
      lhs :|: rhs ->
        GHC.mkTyConApp concatTyCon [ render lhs, render rhs ]
      top :-: bottom ->
        GHC.mkTyConApp vcatTyCon [ render top, render bottom ]