{-# LANGUAGE CPP            #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ViewPatterns   #-}

-- | Support for scalable large records
--
-- = Usage
--
-- > {-# OPTIONS_GHC -fplugin=Data.Record.Plugin #-}
-- >
-- > import Data.Record.Plugin
-- >
-- > {-# ANN type B largeRecord #-}
-- > data B a = B {a :: a, b :: String}
-- >   deriving stock (Show, Eq, Ord)
--
-- See 'LargeRecordOptions' for the list of all possible annotations.
--
-- = Usage with @record-dot-preprocessor@
--
-- The easiest way to use both plugins together is to do
--
-- > {-# OPTIONS_GHC -fplugin=Data.Record.Plugin.WithRDP #-}
--
-- You /can/ also load them separately, but if you do, you need to be careful
-- with the order. Unfortunately, the correct order is different in different
-- ghc versions. Prior to ghc 9.4, the plugins must be loaded like this:
--
-- > {-# OPTIONS_GHC -fplugin=RecordDotPreprocessor -fplugin=Data.Record.Plugin #-}
--
-- From ghc 9.4 and up, they need to be loaded in the opposite order:
--
-- > {-# OPTIONS_GHC -fplugin=Data.Record.Plugin -fplugin=RecordDotPreprocessor #-}
module Data.Record.Plugin (
    -- * Annotations
    LargeRecordOptions(..)
  , largeRecord
    -- * For use by ghc
  , plugin
  ) where

import Control.Monad.Except
import Control.Monad.Trans.Writer.CPS
import Data.List (intersperse)
import Data.Map.Strict (Map)
import Data.Set (Set)
import Data.Traversable (for)
import Language.Haskell.TH (Extension(..))

import qualified Data.Map.Strict as Map
import qualified Data.Set        as Set

import Data.Record.Internal.Plugin.CodeGen (genLargeRecord)
import Data.Record.Internal.GHC.Fresh
import Data.Record.Internal.GHC.Shim
import Data.Record.Internal.GHC.TemplateHaskellStyle
import Data.Record.Internal.Plugin.Exception
import Data.Record.Internal.Plugin.Options
import Data.Record.Internal.Plugin.Record
import Data.Record.Internal.Plugin.Names

#if __GLASGOW_HASKELL__ >= 902
import GHC.Utils.Logger (getLogger)
#endif

#if __GLASGOW_HASKELL__ == 902
import GHC.Types.Error (mkWarnMsg, mkErr, mkDecorated)
import GHC.Driver.Errors (printOrThrowWarnings)
#endif

#if __GLASGOW_HASKELL__ >= 904
import GHC.Driver.Config.Diagnostic (initDiagOpts)
import GHC.Driver.Errors (printOrThrowDiagnostics)
import GHC.Driver.Errors.Types (GhcMessage(GhcUnknownMessage))
import GHC.Types.Error (mkPlainError, mkMessages, mkPlainDiagnostic)
import GHC.Utils.Error (mkMsgEnvelope, mkErrorMsgEnvelope)
#endif

{-------------------------------------------------------------------------------
  Top-level: the plugin proper
-------------------------------------------------------------------------------}

-- | The @large-records@ source plugin
--
-- Installs a 'parsedResultAction' that runs 'transformDecls' over the parsed
-- module; the first two arguments ghc passes to that action (the @[String]@
-- and the 'ModSummary') are ignored. 'pluginRecompile' is set to 'purePlugin'.
plugin :: Plugin
plugin = defaultPlugin {
      parsedResultAction = \_ _ -> ignoreMessages aux
    , pluginRecompile    = purePlugin
    }
  where
    -- From ghc 9.4 the action works on a 'ParsedResult' rather than on the
    -- 'HsParsedModule' directly: we transform only the module and hand the
    -- second field back unchanged. On older versions there is nothing to unwrap.
#if __GLASGOW_HASKELL__ >= 904
    ignoreMessages f (ParsedResult modl msgs) =
            (\modl' -> ParsedResult modl' msgs) <$> f modl
#else
    ignoreMessages :: a -> a
    ignoreMessages = id
#endif

    -- Replace the module inside the parse result by its transformed version,
    -- leaving all other fields of 'HsParsedModule' untouched.
    aux :: HsParsedModule -> Hsc HsParsedModule
    aux parsed@HsParsedModule{hpm_module = modl} = do
        modl' <- transformDecls modl
        pure $ parsed { hpm_module = modl' }

{-------------------------------------------------------------------------------
  Transform datatype declarations
-------------------------------------------------------------------------------}

-- | Apply 'transformDecl' to every top-level declaration of the module
--
-- Each declaration is replaced by the list of declarations that
-- 'transformDecl' returns for it. After the traversal we
--
-- * run 'checkEnabledExtensions', and
-- * report, through 'issueError', every annotated name that 'transformDecl'
--   did not record as processed.
transformDecls :: LHsModule -> Hsc LHsModule
transformDecls (L l modl@HsModule{hsmodDecls = decls}) = do
    -- The writer output is the set of names 'transformDecl' recorded
    (decls', transformed) <- runWriterT $ for decls $ transformDecl largeRecords

    checkEnabledExtensions l

    -- Check for annotations without corresponding types
    let untransformed = Map.keysSet largeRecords `Set.difference` transformed
    unless (Set.null untransformed) $
      issueError l $ vcat $
          text "These large-record annotations were not applied:"
        : [text (" - " ++ n) | n <- Set.toList untransformed]

    -- Only the declarations are replaced; every other field of the module
    -- (including its imports) is returned unchanged.
    pure $ L l $ modl{hsmodDecls = concat decls'}
  where
    -- Annotations found in this module, keyed by the annotated name
    largeRecords :: Map String [(SrcSpan, LargeRecordOptions)]
    largeRecords = getLargeRecordOptions modl

-- | Transform a single top-level declaration
--
-- The first argument maps names to the annotations found for them. For a
-- datatype declaration (matched by 'DataD') we look up its name:
--
-- * no annotation: the declaration is returned unchanged;
-- * more than one annotation: a \"Conflicting annotations\" error is issued;
-- * exactly one annotation: the name is recorded in the writer output and the
--   declaration is replaced by the result of 'genLargeRecord', unless
--   'viewRecord' fails, in which case its exception is issued as an error.
--
-- Any other kind of declaration is returned unchanged.
transformDecl ::
     Map String [(SrcSpan, LargeRecordOptions)]
  -> LHsDecl GhcPs
  -> WriterT (Set String) Hsc [LHsDecl GhcPs]
transformDecl largeRecords decl@(reLoc -> L l _) =
    case decl of
      DataD (nameBase -> name) _ _ _ ->
        case Map.findWithDefault [] name largeRecords of
          [] ->
            -- Not a large record. Leave alone.
            pure [decl]
          (_:_:_) -> do
            lift $ issueError l $ text ("Conflicting annotations for " ++ name)
            pure [decl]
          [(annLoc, opts)] -> do
            -- Recorded before parsing, so a record that fails in 'viewRecord'
            -- still counts as having had its annotation applied.
            tell (Set.singleton name)
            case runExcept (viewRecord annLoc opts decl) of
              Left e -> do
                lift $ issueError (exceptionLoc e) (exceptionToSDoc e)
                -- Return the declaration unchanged if we cannot parse it
                pure [decl]
              Right r -> lift $ do
                dynFlags <- getDynFlags
                names    <- getQualifiedNames
                newDecls <- runFreshHsc $ genLargeRecord names r dynFlags
                -- With 'debugLargeRecords' set, show the generated code as a
                -- warning attached to the original declaration.
                when (debugLargeRecords opts) $
                  issueWarning l (debugMsg newDecls)
                pure newDecls
      _otherwise ->
        pure [decl]
  where
    -- Render the generated declarations at full depth ('AllTheWay')
    debugMsg :: [LHsDecl GhcPs] -> SDoc
    debugMsg newDecls = pprSetDepth AllTheWay $ vcat $
          text "large-records: splicing in the following definitions:"
        : map ppr newDecls

{-------------------------------------------------------------------------------
  Check for enabled extensions

  In ghc 8.10 and up there are DynFlags plugins, which we could use to enable
  these extensions for the user. Since this is not available in 8.8 however we
  will not make use of this for now. (There is also reason to believe that these
  may be removed again in later ghc releases.)
-------------------------------------------------------------------------------}

-- | Warn about language extensions that are required but not enabled
--
-- Every entry of @requiredExtensions@ for which 'isEnabled' returns 'False' is
-- listed in a single warning at the given location. Nothing is reported when
-- all requirements are met.
checkEnabledExtensions :: SrcSpan -> Hsc ()
checkEnabledExtensions l = do
    dynFlags <- getDynFlags
    let missing :: [RequiredExtension]
        missing = filter (not . isEnabled dynFlags) requiredExtensions
    unless (null missing) $
      -- We issue a warning here instead of an error, for better integration
      -- with HLS. Frankly, I'm not entirely sure what's going on there.
      issueWarning l $ vcat $
          text "Please enable these extensions for use with large-records:"
        : map ppr missing
  where
    -- Each entry lists alternatives; one enabled alternative is sufficient.
    requiredExtensions :: [RequiredExtension]
    requiredExtensions = [
          RequiredExtension [ConstraintKinds]
        , RequiredExtension [DataKinds]
        , RequiredExtension [ExistentialQuantification, GADTs]
        , RequiredExtension [FlexibleInstances]
        , RequiredExtension [MultiParamTypeClasses]
        , RequiredExtension [ScopedTypeVariables]
        , RequiredExtension [TypeFamilies]
        , RequiredExtension [UndecidableInstances]
        ]

-- | Required extension
--
-- The list is used to represent alternative extensions that could all work
-- (e.g., @GADTs@ and @ExistentialQuantification@).
--
-- A single-constructor, single-field wrapper, hence a @newtype@.
newtype RequiredExtension = RequiredExtension [Extension]

-- | Renders the alternatives separated by the word @or@
instance Outputable RequiredExtension where
  ppr (RequiredExtension exts) = hsep . intersperse (text "or") $ map ppr exts

-- | Is this requirement satisfied by the given 'DynFlags'?
--
-- A requirement holds as soon as /any/ of its alternative extensions is
-- enabled according to 'xopt'. A requirement with no alternatives is never
-- satisfied.
isEnabled :: DynFlags -> RequiredExtension -> Bool
isEnabled dynflags (RequiredExtension exts) = any (`xopt` dynflags) exts

{-------------------------------------------------------------------------------
  Internal auxiliary
-------------------------------------------------------------------------------}

-- | Report an error at the given location
--
-- The message is wrapped in the error envelope appropriate for the ghc version
-- and passed to 'throwOneError', so this action does not return normally.
-- Names in the message are printed with 'neverQualify'.
issueError :: SrcSpan -> SDoc -> Hsc ()
issueError l errMsg = do
#if __GLASGOW_HASKELL__ == 902
    throwOneError $
      mkErr l neverQualify (mkDecorated [errMsg])
#elif __GLASGOW_HASKELL__ >= 904
    throwOneError $
      mkErrorMsgEnvelope
        l
        neverQualify
        (GhcUnknownMessage $ mkPlainError [] errMsg)
#else
    dynFlags <- getDynFlags
    throwOneError $
      mkErrMsg dynFlags l neverQualify errMsg
#endif

-- | Report a warning at the given location
--
-- The message is wrapped in the warning envelope appropriate for the ghc
-- version and handed to ghc's own @printOrThrow…@ function as a singleton
-- bag. Names in the message are printed with 'neverQualify'.
issueWarning :: SrcSpan -> SDoc -> Hsc ()
issueWarning l errMsg = do
    dynFlags <- getDynFlags
#if __GLASGOW_HASKELL__ == 902
    logger <- getLogger
    liftIO $ printOrThrowWarnings logger dynFlags . bag $
      mkWarnMsg l neverQualify errMsg
#elif __GLASGOW_HASKELL__ >= 904
    logger <- getLogger
    liftIO $ printOrThrowDiagnostics logger (initDiagOpts dynFlags) . mkMessages . bag $
      mkMsgEnvelope
        (initDiagOpts dynFlags)
        l
        neverQualify
        (GhcUnknownMessage $ mkPlainDiagnostic WarningWithoutFlag [] errMsg)
#else
    liftIO $ printOrThrowWarnings dynFlags . bag $
      mkWarnMsg dynFlags l neverQualify errMsg
#endif
  where
    -- Singleton bag
    bag :: a -> Bag a
    bag = listToBag . (:[])