{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module GHC.TcPlugin.API.Internal
  ( 
    MonadTcPlugin(..), MonadTcPluginWork
  , unsafeLiftThroughTcM
    
  , TcPlugin(..), TcPluginStage(..)
  , TcPluginSolver
  , TcPluginM(..)
  , TcPluginErrorMessage(..)
  , TcPluginRewriter
  , askRewriteEnv
  , askDeriveds
  , askEvBinds
  , mkTcPlugin
  , mkTcPluginErrorTy
  )
  where
import Data.Coerce
  ( Coercible )
import Data.Kind
  ( Constraint, Type )
import GHC.TypeLits
  ( TypeError, ErrorMessage(..) )
import Control.Monad.Trans.Reader
  ( ReaderT(..) )
import qualified GHC.Builtin.Names
  as GHC.TypeLits
    ( errorMessageTypeErrorFamName
    , typeErrorTextDataConName
    , typeErrorAppendDataConName
    , typeErrorVAppendDataConName
    , typeErrorShowTypeDataConName
    )
import qualified GHC.Builtin.Types
  as GHC
    ( constraintKind )
import qualified GHC.Core.DataCon
  as GHC
    ( promoteDataCon )
import qualified GHC.Core.TyCon
  as GHC
    ( TyCon )
import qualified GHC.Core.TyCo.Rep
  as GHC
    ( PredType, Type(..), TyLit(..) )
import qualified GHC.Core.Type
  as GHC
    ( mkTyConApp, tcTypeKind )
import qualified GHC.Data.FastString
  as GHC
    ( fsLit )
import qualified GHC.Tc.Plugin
  as GHC
    ( tcLookupDataCon, tcLookupTyCon )
import qualified GHC.Tc.Types
  as GHC
    ( TcM, TcPlugin(..), TcPluginM
    , TcPluginSolver
#ifdef HAS_REWRITING
    , TcPluginRewriter
#else
    , getEvBindsTcPluginM
#endif
    , runTcPluginM, unsafeTcPluginTcM
    )
#ifdef HAS_REWRITING
import GHC.Tc.Types
    ( TcPluginSolveResult
    , TcPluginRewriteResult
    , RewriteEnv
    )
#endif
import qualified GHC.Tc.Types.Constraint
  as GHC
    ( Ct )
import qualified GHC.Tc.Types.Evidence
  as GHC
    ( EvBindsVar )
import qualified GHC.Types.Unique.FM
  as GHC
    ( UniqFM )
#ifndef HAS_REWRITING
import GHC.TcPlugin.API.Internal.Shim
  ( TcPluginSolveResult, TcPluginRewriteResult(..)
  , RewriteEnv
  , shimRewriter
  )
#endif
-- | The stages of a type-checking plugin.
--
-- Used at the type level to index 'TcPluginM', so that each stage gets its
-- own monad with its own capabilities.
data TcPluginStage
  = Init
  | Solve
  | Rewrite
  | Stop
-- | A constraint solver, run in the 'Solve' stage.
type TcPluginSolver
  =  [GHC.Ct] -- ^ givens
  -> [GHC.Ct] -- ^ wanteds
  -> TcPluginM Solve TcPluginSolveResult
-- | A rewriter, run in the 'Rewrite' stage.
type TcPluginRewriter
  =  [GHC.Ct]     -- ^ givens
  -> [GHC.Type]   -- ^ argument types (presumably of the type family application being rewritten; confirm)
  -> TcPluginM Rewrite TcPluginRewriteResult
-- | A type-checking plugin written against this library.
--
-- The existential state @s@ is produced by 'tcPluginInit' and handed to
-- every other stage. Convert to a GHC plugin with 'mkTcPlugin'.
data TcPlugin = forall s. TcPlugin
  { tcPluginInit :: TcPluginM Init s
      -- ^ Initialise the plugin, producing the state passed to the other stages.
  , tcPluginSolve :: s -> TcPluginSolver
      -- ^ Solve constraints, given the plugin state.
      --
      -- See 'TcPluginSolver' for the remaining arguments (givens, wanteds).
      -- When @HAS_REWRITING@ is undefined, 'mkTcPlugin' runs the rewriters
      -- from within this solver stage.
  , tcPluginRewrite
      :: s -> GHC.UniqFM
#if MIN_VERSION_ghc(9,0,0)
                GHC.TyCon
#endif
                TcPluginRewriter
    -- ^ Rewriters, given the plugin state, in a map keyed by 'GHC.TyCon'
    -- (presumably the type family each rewriter handles; confirm).
  , tcPluginStop :: s -> TcPluginM Stop ()
    -- ^ Stop-stage action, given the plugin state (presumably for cleanup).
  }
type TcPluginM :: TcPluginStage -> ( Type -> Type )
-- | The monad used in each stage of a type-checking plugin, indexed by
-- 'TcPluginStage'. Each stage has its own newtype instance.
data family TcPluginM s
-- | Initialisation-stage actions are plain 'GHC.TcPluginM' actions.
newtype instance TcPluginM Init a =
  TcPluginInitM { tcPluginInitM :: GHC.TcPluginM a }
  deriving newtype ( Functor, Applicative, Monad )
-- | Solving-stage actions take the builtin definitions and the evidence
-- binding variable (plus, when @HAS_DERIVEDS@ is defined, the Derived
-- constraints) as arguments. The monad instances are borrowed from the
-- equivalent stack of 'ReaderT's.
#ifdef HAS_DERIVEDS
newtype instance TcPluginM Solve a =
  TcPluginSolveM { tcPluginSolveM :: BuiltinDefs -> GHC.EvBindsVar -> [GHC.Ct] -> GHC.TcPluginM a }
  deriving ( Functor, Applicative, Monad )
    via ( ReaderT BuiltinDefs ( ReaderT GHC.EvBindsVar ( ReaderT [GHC.Ct] GHC.TcPluginM ) ) )
#else
newtype instance TcPluginM Solve a =
  TcPluginSolveM { tcPluginSolveM :: BuiltinDefs -> GHC.EvBindsVar -> GHC.TcPluginM a }
  deriving ( Functor, Applicative, Monad )
    via ( ReaderT BuiltinDefs ( ReaderT GHC.EvBindsVar GHC.TcPluginM ) )
#endif
-- | Rewriting-stage actions take the builtin definitions and the
-- 'RewriteEnv' as arguments; the monad instances come from the equivalent
-- 'ReaderT' stack.
newtype instance TcPluginM Rewrite a =
  TcPluginRewriteM { tcPluginRewriteM :: BuiltinDefs -> RewriteEnv -> GHC.TcPluginM a }
  deriving ( Functor, Applicative, Monad )
    via ( ReaderT BuiltinDefs ( ReaderT RewriteEnv GHC.TcPluginM ) )
-- | Stop-stage actions are plain 'GHC.TcPluginM' actions.
newtype instance TcPluginM Stop a =
  TcPluginStopM { tcPluginStopM :: GHC.TcPluginM a }
  deriving newtype ( Functor, Applicative, Monad )
-- | Retrieve the evidence binding variable available in the solving stage.
askEvBinds :: TcPluginM Solve GHC.EvBindsVar
askEvBinds =
#ifdef HAS_DERIVEDS
  TcPluginSolveM ( \ _builtins evBindsVar _derivedCts -> pure evBindsVar )
#else
  TcPluginSolveM ( \ _builtins evBindsVar -> pure evBindsVar )
#endif
-- | Retrieve the Derived constraints passed to the solver.
--
-- When @HAS_DERIVEDS@ is undefined there are none, so this returns @[]@.
askDeriveds :: TcPluginM Solve [GHC.Ct]
#ifdef HAS_DERIVEDS
askDeriveds = TcPluginSolveM ( \ _builtins _evBindsVar derivedCts -> pure derivedCts )
#else
askDeriveds = pure []
#endif
-- | Retrieve the 'RewriteEnv' of the current rewriting-stage invocation.
-- The builtin definitions (first argument) are ignored.
askRewriteEnv :: TcPluginM Rewrite RewriteEnv
askRewriteEnv = TcPluginRewriteM ( const pure )
type  MonadTcPlugin :: ( Type -> Type ) -> Constraint
-- | Monads in which 'GHC.TcPluginM' computations can be run, and which can
-- (unsafely) interoperate with 'GHC.TcM'.
--
-- The quantified superclass requires @m@ to map coercible types to
-- coercible types, so results can be coerced under @m@.
class ( Monad m, ( forall x y. Coercible x y => Coercible (m x) (m y) ) ) => MonadTcPlugin m where
  {-# MINIMAL liftTcPluginM, unsafeWithRunInTcM #-}

  -- | Lift a computation from GHC's 'GHC.TcPluginM' monad.
  liftTcPluginM :: GHC.TcPluginM a -> m a

  -- | Lift a computation from 'GHC.TcM'.
  --
  -- By default, goes through 'GHC.unsafeTcPluginTcM' and then 'liftTcPluginM'.
  unsafeLiftTcM :: GHC.TcM a -> m a
  unsafeLiftTcM tcAction = liftTcPluginM ( GHC.unsafeTcPluginTcM tcAction )

  -- | Run a 'GHC.TcM' computation that is handed a function for running
  -- @m@ actions inside 'GHC.TcM', returning its result in @m@.
  unsafeWithRunInTcM :: ( ( forall a. m a -> GHC.TcM a ) -> GHC.TcM b ) -> m b
-- | Initialisation-stage actions wrap 'GHC.TcPluginM' directly.
instance MonadTcPlugin ( TcPluginM Init ) where
  liftTcPluginM = TcPluginInitM
  unsafeWithRunInTcM withRunner = unsafeLiftTcM ( withRunner runInit )
    where
      -- Unwrap an init-stage action and run it in TcM.
      runInit :: TcPluginM Init x -> GHC.TcM x
#ifdef HAS_REWRITING
      runInit ( TcPluginInitM action ) = GHC.runTcPluginM action
#else
      -- No EvBindsVar exists in this stage; the placeholder errors if forced.
      runInit ( TcPluginInitM action ) =
        GHC.runTcPluginM action ( error "tcPluginInit: cannot access EvBindsVar" )
#endif
-- | Solving-stage actions receive the builtin definitions, the evidence
-- binding variable and (when @HAS_DERIVEDS@ is defined) the Derived
-- constraints as arguments. Lifting simply ignores those arguments.
instance MonadTcPlugin ( TcPluginM Solve ) where
  liftTcPluginM  = TcPluginSolveM
#ifdef HAS_DERIVEDS
                 . ( \ ma _defs _evBinds _deriveds -> ma )
#else
                 . ( \ ma _defs _evBinds -> ma )
#endif
  -- Capture the current arguments, then hand the continuation a runner that
  -- re-applies them to any solving-stage action before running it in TcM.
  unsafeWithRunInTcM runInTcM
    = TcPluginSolveM
      \ builtinDefs
        evBinds
#ifdef HAS_DERIVEDS
        deriveds
#endif
      ->
        GHC.unsafeTcPluginTcM $ runInTcM
#ifdef HAS_REWRITING
          ( GHC.runTcPluginM
#ifdef HAS_DERIVEDS
          . ( \ f -> f builtinDefs evBinds deriveds )
#else
          . ( \ f -> f builtinDefs evBinds )
#endif
          . tcPluginSolveM )
#else
          -- Without HAS_REWRITING, runTcPluginM is also given the EvBindsVar.
          -- NOTE(review): this branch uses deriveds unconditionally, so it
          -- assumes HAS_DERIVEDS is defined whenever HAS_REWRITING is not.
          ( ( `GHC.runTcPluginM` evBinds )
          . ( \ f -> f builtinDefs evBinds deriveds )
          . tcPluginSolveM
          )
#endif
-- | Rewriting-stage actions receive the builtin definitions and the
-- 'RewriteEnv' as arguments. Lifting ignores them.
instance MonadTcPlugin ( TcPluginM Rewrite ) where
  liftTcPluginM ma = TcPluginRewriteM ( \ _builtins _rewriteEnv -> ma )
  unsafeWithRunInTcM withRunner = TcPluginRewriteM \ defs env ->
    let
      -- Re-apply the captured arguments to a rewriting-stage action.
      runRewrite :: TcPluginM Rewrite x -> GHC.TcM x
      runRewrite ( TcPluginRewriteM f ) =
#ifdef HAS_REWRITING
        GHC.runTcPluginM ( f defs env )
#else
        GHC.runTcPluginM ( f defs env ) ( error "tcPluginRewrite: cannot access EvBindsVar" )
#endif
    in
      GHC.unsafeTcPluginTcM ( withRunner runRewrite )
-- | Stop-stage actions wrap 'GHC.TcPluginM' directly.
instance MonadTcPlugin ( TcPluginM Stop ) where
  liftTcPluginM = TcPluginStopM
  unsafeWithRunInTcM withRunner = unsafeLiftTcM ( withRunner runStop )
    where
      -- Unwrap a stop-stage action and run it in TcM.
      runStop :: TcPluginM Stop x -> GHC.TcM x
#ifdef HAS_REWRITING
      runStop ( TcPluginStopM action ) = GHC.runTcPluginM action
#else
      -- No EvBindsVar exists in this stage; the placeholder errors if forced.
      runStop ( TcPluginStopM action ) =
        GHC.runTcPluginM action ( error "tcPluginStop: cannot access EvBindsVar" )
#endif
-- | Apply a transformation on 'GHC.TcM' computations to an action in any
-- 'MonadTcPlugin' monad, by running the action in 'GHC.TcM' via
-- 'unsafeWithRunInTcM'.
unsafeLiftThroughTcM :: MonadTcPlugin m => ( GHC.TcM a -> GHC.TcM b ) -> m a -> m b
unsafeLiftThroughTcM transform action =
  unsafeWithRunInTcM ( \ run -> transform ( run action ) )
-- | Convert a 'TcPlugin' written against this library into the
-- 'GHC.TcPlugin' record that GHC expects.
--
-- The builtin definitions used to build custom type errors are looked up
-- once, in the init stage. They are then threaded, together with the user
-- state, to every later stage.
--
-- When @HAS_REWRITING@ is undefined, GHC has no separate rewriting stage.
-- In that case the user rewriters are run from within the solver via
-- 'shimRewriter'.
mkTcPlugin
  :: TcPlugin     -- ^ type-checking plugin written with this library
  -> GHC.TcPlugin -- ^ type-checking plugin for GHC
mkTcPlugin ( TcPlugin
              { tcPluginInit = tcPluginInit :: TcPluginM Init userDefs
              , tcPluginSolve
              , tcPluginRewrite
              , tcPluginStop
              }
           ) =
  GHC.TcPlugin
    { GHC.tcPluginInit    = adaptUserInit    tcPluginInit
#ifdef HAS_REWRITING
    , GHC.tcPluginSolve   = adaptUserSolve   tcPluginSolve
    , GHC.tcPluginRewrite = adaptUserRewrite tcPluginRewrite
#else
    , GHC.tcPluginSolve   = adaptUserSolveAndRewrite
                              tcPluginSolve tcPluginRewrite
#endif
    , GHC.tcPluginStop    = adaptUserStop    tcPluginStop
    }
  where
    -- Look up the builtin definitions, then run the user initialisation.
    adaptUserInit :: TcPluginM Init userDefs -> GHC.TcPluginM ( TcPluginDefs userDefs )
    adaptUserInit userInit = do
      tcPluginBuiltinDefs <- initBuiltinDefs
      tcPluginUserDefs    <- tcPluginInitM userInit
      pure ( TcPluginDefs { tcPluginBuiltinDefs, tcPluginUserDefs })
#ifdef HAS_REWRITING
    -- Run the user solver, supplying the EvBindsVar given by GHC and, when
    -- Derived constraints still exist, those as well.
    adaptUserSolve :: ( userDefs -> TcPluginSolver )
                   -> TcPluginDefs userDefs
                   -> GHC.EvBindsVar
                   -> GHC.TcPluginSolver
    adaptUserSolve userSolve ( TcPluginDefs { tcPluginUserDefs, tcPluginBuiltinDefs } )
      evBindsVar
#ifdef HAS_DERIVEDS
      = \ givens deriveds wanteds ->
        tcPluginSolveM ( userSolve tcPluginUserDefs givens wanteds )
          tcPluginBuiltinDefs evBindsVar deriveds
#else
      -- Without Derived constraints, GHC passes only givens and wanteds.
      -- Binding a third argument here would capture the wanteds in the
      -- wrong slot and no longer match the solver type.
      = \ givens wanteds ->
        tcPluginSolveM ( userSolve tcPluginUserDefs givens wanteds )
          tcPluginBuiltinDefs evBindsVar
#endif
    -- Wrap each user rewriter so it receives the builtin definitions and the
    -- RewriteEnv provided by GHC.
    adaptUserRewrite :: ( userDefs -> GHC.UniqFM GHC.TyCon TcPluginRewriter )
                     -> TcPluginDefs userDefs -> GHC.UniqFM GHC.TyCon GHC.TcPluginRewriter
    adaptUserRewrite userRewrite ( TcPluginDefs { tcPluginUserDefs, tcPluginBuiltinDefs })
      = fmap
          ( \ userRewriter rewriteEnv givens tys ->
            tcPluginRewriteM ( userRewriter givens tys ) tcPluginBuiltinDefs rewriteEnv
          )
          ( userRewrite tcPluginUserDefs )
#else
    -- No native rewriting stage: run both the user rewriters and the user
    -- solver from GHC's solver via the shim.
    adaptUserSolveAndRewrite
      :: ( userDefs -> TcPluginSolver )
      -> ( userDefs -> GHC.UniqFM
#if MIN_VERSION_ghc(9,0,0)
                         GHC.TyCon
#endif
                         TcPluginRewriter
         )
      -> TcPluginDefs userDefs
      -> GHC.TcPluginSolver
    adaptUserSolveAndRewrite userSolve userRewrite ( TcPluginDefs { tcPluginUserDefs, tcPluginBuiltinDefs } )
      = \ givens deriveds wanteds -> do
        -- The EvBindsVar is not passed in here; fetch it from the monad.
        evBindsVar <- GHC.getEvBindsTcPluginM
        shimRewriter
          givens deriveds wanteds
          ( fmap
              ( \ userRewriter rewriteEnv gs tys ->
                tcPluginRewriteM ( userRewriter gs tys )
                  tcPluginBuiltinDefs rewriteEnv
              )
              ( userRewrite tcPluginUserDefs )
          )
          ( \ gs ds ws ->
            tcPluginSolveM ( userSolve tcPluginUserDefs gs ws )
              tcPluginBuiltinDefs evBindsVar ds
          )
#endif
    -- Run the user stop action with the user state.
    adaptUserStop :: ( userDefs -> TcPluginM Stop () ) -> TcPluginDefs userDefs -> GHC.TcPluginM ()
    adaptUserStop userStop ( TcPluginDefs { tcPluginUserDefs } ) =
      tcPluginStopM $ userStop tcPluginUserDefs
type  MonadTcPluginWork :: ( Type -> Type ) -> Constraint
-- | Plugin monads that can emit new work, such as custom type errors built
-- with 'mkTcPluginErrorTy'.
--
-- Working instances exist for the solving and rewriting stages, which carry
-- the builtin definitions looked up during initialisation. The init and stop
-- stages get instances carrying a 'TypeError' constraint instead.
class MonadTcPlugin m => MonadTcPluginWork m where
  -- Empty MINIMAL pragma: the only method has a default.
  {-# MINIMAL #-} 
  -- | Access the builtin definitions used to build custom type errors.
  askBuiltins :: m BuiltinDefs
  askBuiltins = error "askBuiltins: no default implementation"
-- | The solving stage carries the builtin definitions as its first argument.
instance MonadTcPluginWork ( TcPluginM Solve ) where
#ifdef HAS_DERIVEDS
  askBuiltins = TcPluginSolveM ( \ defs _evBindsVar _derivedCts -> pure defs )
#else
  askBuiltins = TcPluginSolveM ( \ defs _evBindsVar -> pure defs )
#endif
-- | The rewriting stage carries the builtin definitions as its first
-- argument. The second argument is the 'RewriteEnv' and is ignored here.
instance MonadTcPluginWork ( TcPluginM Rewrite ) where
  askBuiltins = TcPluginRewriteM ( \ defs _rewriteEnv -> pure defs )
-- | The init stage cannot emit new work. The 'TypeError' constraint turns
-- any use of this instance into a compile-time error with the given message.
-- The method body is therefore not expected to run.
instance TypeError ( 'Text "Cannot emit new work in 'tcPluginInit'." )
      => MonadTcPluginWork ( TcPluginM Init ) where
  askBuiltins = error "Cannot emit new work in 'tcPluginInit'."
-- | The stop stage cannot emit new work. The 'TypeError' constraint turns
-- any use of this instance into a compile-time error with the given message.
-- The method body is therefore not expected to run.
instance TypeError ( 'Text "Cannot emit new work in 'tcPluginStop'." )
      => MonadTcPluginWork ( TcPluginM Stop ) where
  askBuiltins = error "Cannot emit new work in 'tcPluginStop'."
-- | A custom type error message, turned into a type-level error message by
-- 'mkTcPluginErrorTy'.
data TcPluginErrorMessage
  = Txt !String
  -- ^ Text, rendered via the promoted @Text@ constructor.
  | PrintType !GHC.Type
  -- ^ A type, rendered via the promoted @ShowType@ constructor.
  | (:|:) !TcPluginErrorMessage !TcPluginErrorMessage
  -- ^ Two messages side by side (the type-error append constructor).
  | (:-:) !TcPluginErrorMessage !TcPluginErrorMessage
  -- ^ Two messages stacked vertically (the type-error vertical-append constructor).
-- NOTE(review): with these fixities, :-: binds tighter than :|:, so
-- a :|: b :-: c parses as a :|: (b :-: c). Confirm this is intended.
infixl 5 :|:
infixl 6 :-:
-- | Build a custom type error constraint from a 'TcPluginErrorMessage'.
--
-- The result is the @TypeError@ family applied to the constraint kind and
-- the type-level rendering of the message.
mkTcPluginErrorTy :: MonadTcPluginWork m => TcPluginErrorMessage -> m GHC.PredType
mkTcPluginErrorTy msg = toErrorPred <$> askBuiltins
  where
    toErrorPred :: BuiltinDefs -> GHC.PredType
    toErrorPred defs =
      GHC.mkTyConApp ( typeErrorTyCon defs )
        [ GHC.constraintKind, interpretErrorMessage defs msg ]
-- | Type constructors needed to build custom type error messages.
-- They are looked up once by 'initBuiltinDefs'.
data BuiltinDefs =
  BuiltinDefs
    { typeErrorTyCon :: !GHC.TyCon -- ^ the @TypeError@ type family
    , textTyCon      :: !GHC.TyCon -- ^ the promoted @Text@ data constructor
    , showTypeTyCon  :: !GHC.TyCon -- ^ the promoted @ShowType@ data constructor
    , concatTyCon    :: !GHC.TyCon -- ^ the promoted horizontal-append data constructor
    , vcatTyCon      :: !GHC.TyCon -- ^ the promoted vertical-append data constructor
    }
-- | The state threaded through GHC's plugin stages. It pairs the builtin
-- definitions with the user state produced by the user init action.
data TcPluginDefs s
  = TcPluginDefs
  { tcPluginBuiltinDefs :: !BuiltinDefs
  , tcPluginUserDefs    :: !s
  }
-- | Look up the type constructors used to build custom type errors.
--
-- Lookups run in the order of the 'BuiltinDefs' fields. Data constructors
-- are promoted to the type level.
initBuiltinDefs :: GHC.TcPluginM BuiltinDefs
initBuiltinDefs =
  BuiltinDefs
    <$> GHC.tcLookupTyCon GHC.TypeLits.errorMessageTypeErrorFamName
    <*> promoted GHC.TypeLits.typeErrorTextDataConName
    <*> promoted GHC.TypeLits.typeErrorShowTypeDataConName
    <*> promoted GHC.TypeLits.typeErrorAppendDataConName
    <*> promoted GHC.TypeLits.typeErrorVAppendDataConName
  where
    -- Look up a data constructor and promote it to a type constructor.
    promoted nm = GHC.promoteDataCon <$> GHC.tcLookupDataCon nm
-- | Translate a 'TcPluginErrorMessage' into the corresponding type-level
-- error message, using the looked-up builtin type constructors.
interpretErrorMessage :: BuiltinDefs -> TcPluginErrorMessage -> GHC.PredType
interpretErrorMessage defs = render
  where
    render :: TcPluginErrorMessage -> GHC.PredType
    render message = case message of
      Txt str ->
        GHC.mkTyConApp ( textTyCon defs ) [ GHC.LitTy ( GHC.StrTyLit ( GHC.fsLit str ) ) ]
      -- ShowType also takes a kind argument; pass the actual kind of the
      -- type so the application is well-kinded.
      PrintType ty ->
        GHC.mkTyConApp ( showTypeTyCon defs ) [ GHC.tcTypeKind ty, ty ]
      lhs :|: rhs ->
        GHC.mkTyConApp ( concatTyCon defs ) [ render lhs, render rhs ]
      top :-: bottom ->
        GHC.mkTyConApp ( vcatTyCon defs ) [ render top, render bottom ]