{-# LANGUAGE CPP                   #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module THSH.Internal.THUtils
  ( reportErrorAt
  , toName
  , lookupName
  , freeVariableByNameExists
  ) where

import           GHC                        (SrcSpan, moduleNameString)
import           GHC.Tc.Types               (TcM)
import           GHC.Tc.Utils.Monad         (addErrAt)
#if MIN_VERSION_ghc(9,8,0)
import           GHC.Tc.Errors.Types        (TcRnMessage (TcRnUnknownMessage))
import           GHC.Types.Error            (NoDiagnosticOpts (NoDiagnosticOpts), UnknownDiagnostic (UnknownDiagnostic))
import           GHC.Utils.Error            (mkPlainError, noHints)
import           GHC.Utils.Outputable       (text)
#elif MIN_VERSION_ghc(9,6,0)
import           GHC.Tc.Errors.Types        (TcRnMessage (TcRnUnknownMessage))
import           GHC.Types.Error            (UnknownDiagnostic (UnknownDiagnostic))
import           GHC.Utils.Error            (mkPlainError, noHints)
import           GHC.Utils.Outputable       (text)
#elif MIN_VERSION_ghc(9,4,0)
import           GHC.Driver.Errors.Types    (GhcMessage (GhcPsMessage))
import           GHC.Parser.Errors.Types    (PsMessage (PsUnknownMessage))
import           GHC.Tc.Errors.Types        (TcRnMessage (TcRnUnknownMessage))
import           GHC.Utils.Error            (mkPlainError, noHints)
import           GHC.Utils.Outputable       (text)
#else
import           Data.String                (fromString)
#endif
import           GHC.Types.Name             (getOccString, occNameString)
import           GHC.Types.Name.Reader      (RdrName (..))
import qualified GHC.Unit.Module            as Module
import qualified Language.Haskell.TH        as TH
import           Language.Haskell.TH.Syntax (Q (Q))
--
import           Data.Maybe                 (isJust)
import           Unsafe.Coerce              (unsafeCoerce)

-- | Report an error at an explicit source location.
--
-- This is similar to the Template Haskell @reportError@ function, but it also
-- takes the 'SrcSpan' to attach the error to, so the error is localised at
-- the correct position inside the TH splice instead of at its beginning.
--
-- The error is emitted by running the type-checker action 'addErrAt' through
-- 'unsafeRunTcM'; see the caveats documented on that function.
--
-- From: PyF.Internal.QQ
reportErrorAt :: SrcSpan -> String -> Q ()
reportErrorAt loc msg = unsafeRunTcM $ addErrAt loc msg'
  where
    -- The value expected by addErrAt is built differently depending on the
    -- GHC version, hence one definition of the message per supported range.
#if MIN_VERSION_ghc(9,8,0)
    msg' = TcRnUnknownMessage (UnknownDiagnostic (const NoDiagnosticOpts) (mkPlainError noHints (text msg)))
#elif MIN_VERSION_ghc(9,6,0)
    msg' = TcRnUnknownMessage (UnknownDiagnostic $ mkPlainError noHints $ text msg)
#elif MIN_VERSION_ghc(9,4,0)
    msg' = TcRnUnknownMessage (GhcPsMessage $ PsUnknownMessage $ mkPlainError noHints $ text msg)
#else
    -- NOTE(review): presumably addErrAt takes a document type with an
    -- IsString instance on these older compilers; confirm before changing.
    msg' = fromString msg
#endif

-- | Run a type-checker ('TcM') action inside the 'Q' monad.
--
-- Stolen from: https://www.tweag.io/blog/2021-01-07-haskell-dark-arts-part-i/
--
-- This allows hacking inside the GHC API and using functions that are not
-- exported by template-haskell.
--
-- This may not always be safe, see https://github.com/guibou/PyF/issues/115.
-- Hence keep it for the "failing path" (i.e. error reporting) only, and do
-- not use it on code paths which are executed otherwise.
--
-- From: PyF.Internal.QQ
unsafeRunTcM :: TcM a -> Q a
-- The Q constructor expects an action polymorphic in its Quasi monad; the
-- coercion passes the TcM action off as such. NOTE(review): this is only
-- sound if the resulting Q action is actually executed in TcM, which is the
-- assumption made by the article linked above.
unsafeRunTcM m = Q (unsafeCoerce m)

-- | Convert a GHC 'RdrName' into a Template Haskell 'TH.Name'.
--
-- * Unqualified names are converted with 'TH.mkName' from their occurrence
--   string.
-- * Qualified names are converted with 'TH.mkName' from @Module.name@.
-- * Exact names are only supported when their occurrence string is @[]@ or
--   @()@, which map to the corresponding built-in constructor names.
--
-- This function is partial: it calls 'error' on @Orig@ names and on any other
-- @Exact@ name.
toName :: RdrName -> TH.Name
toName (Unqual o) = TH.mkName (occNameString o)
toName (Qual m o) = TH.mkName (Module.moduleNameString m <> "." <> occNameString o)
toName (Orig _m _o) = error "PyFMeta: not supported toName (Orig _)"
toName (Exact nm) = case getOccString nm of
  "[]" -> '[]
  "()" -> '()
  _    -> error "toName: exact name encountered"

-- | Check whether a 'RdrName' refers to a value that is currently in scope.
--
-- Unqualified and qualified names are resolved with 'TH.lookupValueName'
-- (qualified names as @Module.name@); the result is 'True' when the lookup
-- finds a name. @Orig@ and @Exact@ names are never looked up and are always
-- reported as existing.
lookupName :: RdrName -> Q Bool
lookupName (Unqual o) = isJust <$> TH.lookupValueName (occNameString o)
lookupName (Qual m o) = isJust <$> TH.lookupValueName (moduleNameString m <> "." <> occNameString o)
-- No idea how to look up these names, so consider that they exist.
lookupName (Orig _m _o) = pure True
lookupName (Exact _) = pure True

-- | Check that a free variable, given together with a caller-supplied
-- location value, is in scope.
--
-- Returns 'Nothing' when 'lookupName' finds the name. Otherwise returns
-- 'Just' a pair of a @Variable not in scope@ message and the location value
-- that was passed in, unchanged.
freeVariableByNameExists :: (b, RdrName) -> Q (Maybe (String, b))
freeVariableByNameExists (loc, name) = do
  exists <- lookupName name
  pure $
    if exists
      then Nothing
      -- 'toName' is only evaluated on this branch. 'lookupName' returns True
      -- for Orig and Exact names, so the partial cases of 'toName' cannot be
      -- reached from here.
      else Just ("Variable not in scope: " <> show (toName name), loc)