{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Record.Plugin (
LargeRecordOptions(..)
, largeRecord
, plugin
) where
import Control.Monad.Except
import Control.Monad.Trans.Writer.CPS
import Data.List (intersperse)
import Data.Map.Strict (Map)
import Data.Set (Set)
import Data.Traversable (for)
import Language.Haskell.TH (Extension(..))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Record.Internal.Plugin.CodeGen (genLargeRecord)
import Data.Record.Internal.GHC.Fresh
import Data.Record.Internal.GHC.Shim
import Data.Record.Internal.GHC.TemplateHaskellStyle
import Data.Record.Internal.Plugin.Exception
import Data.Record.Internal.Plugin.Options
import Data.Record.Internal.Plugin.Record
import Data.Record.Internal.Plugin.Names
#if __GLASGOW_HASKELL__ >= 902
import GHC.Utils.Logger (getLogger)
#endif
#if __GLASGOW_HASKELL__ == 902
import GHC.Types.Error (mkWarnMsg, mkErr, mkDecorated)
import GHC.Driver.Errors (printOrThrowWarnings)
#endif
#if __GLASGOW_HASKELL__ >= 904
import GHC.Driver.Config.Diagnostic (initDiagOpts)
import GHC.Driver.Errors (printOrThrowDiagnostics)
import GHC.Driver.Errors.Types (GhcMessage(GhcUnknownMessage))
import GHC.Types.Error (mkPlainError, mkMessages, mkPlainDiagnostic)
import GHC.Utils.Error (mkMsgEnvelope, mkErrorMsgEnvelope)
#endif
-- | GHC plugin entry point.
--
-- Starts from 'defaultPlugin' and overrides two fields:
--
-- * 'parsedResultAction' ignores its first two arguments and rewrites the
--   declarations of the parsed module through 'transformDecls'.
-- * 'pluginRecompile' is set to 'purePlugin'.
plugin :: Plugin
plugin = defaultPlugin {
      parsedResultAction = \_ _ -> ignoreMessages aux
    , pluginRecompile    = purePlugin
    }
  where
    -- Adapts @aux@ to the argument that the parsed-result action receives.
    -- From GHC 9.4 on, that argument is a @ParsedResult@ pairing the module
    -- with messages: the module is transformed and the messages are passed
    -- through unchanged. On older compilers no adaptation is needed.
#if __GLASGOW_HASKELL__ >= 904
    ignoreMessages f (ParsedResult modl msgs) =
            (\modl' -> ParsedResult modl' msgs) <$> f modl
#else
    ignoreMessages = id
#endif

    -- Replace the module stored in the parse result by its transformed
    -- version, leaving every other field of the record untouched.
    aux :: HsParsedModule -> Hsc HsParsedModule
    aux parsed@HsParsedModule{hpm_module = modl} = do
        modl' <- transformDecls modl
        pure $ parsed { hpm_module = modl' }
-- | Transform all top-level declarations of a module.
--
-- Steps, in order:
--
-- 1. Run 'transformDecl' over every declaration. The writer output collects
--    the names of the records whose annotation was picked up.
-- 2. Run 'checkEnabledExtensions' at the location of the module.
-- 3. Report an error listing every annotated name (a key of @largeRecords@)
--    that was not collected in step 1.
-- 4. Return the module with its declarations replaced by the concatenation
--    of the per-declaration results.
transformDecls :: LHsModule -> Hsc LHsModule
transformDecls (L l modl@HsModule{hsmodDecls = decls}) = do
    (decls', transformed) <- runWriterT $ for decls $ transformDecl largeRecords

    checkEnabledExtensions l

    -- Annotations that did not match any declaration in this module
    let untransformed = Map.keysSet largeRecords `Set.difference` transformed
    unless (Set.null untransformed) $ do
      issueError l $ vcat $
          text "These large-record annotations were not applied:"
        : [text (" - " ++ n) | n <- Set.toList untransformed]

    -- Each declaration may have been replaced by several declarations
    pure $ L l $ modl { hsmodDecls = concat decls' }
  where
    -- Annotations found in the module, keyed by the annotated name
    largeRecords :: Map String [(SrcSpan, LargeRecordOptions)]
    largeRecords = getLargeRecordOptions modl
-- | Transform a single top-level declaration.
--
-- Only data declarations (matched by the @DataD@ pattern) are considered;
-- anything else is returned unchanged as a singleton list. For a data
-- declaration, the annotations registered under its name decide the outcome:
--
-- * none: the declaration is returned unchanged;
-- * two or more: a "Conflicting annotations" error is issued;
-- * exactly one: the name is recorded in the writer output, the declaration
--   is parsed with 'viewRecord', and on success it is replaced by the
--   declarations produced by 'genLargeRecord'.
transformDecl ::
     Map String [(SrcSpan, LargeRecordOptions)]
  -> LHsDecl GhcPs
  -> WriterT (Set String) Hsc [LHsDecl GhcPs]
transformDecl largeRecords decl@(reLoc -> L l _) =
    case decl of
      DataD (nameBase -> name) _ _ _ ->
        case Map.findWithDefault [] name largeRecords of
          [] ->
            -- No annotation for this type: leave it alone
            return [decl]
          (_:_:_) -> do
            lift $ issueError l $ text ("Conflicting annotations for " ++ name)
            -- NOTE(review): presumably unreachable if issueError always
            -- throws; kept so that the branch has the right type.
            return [decl]
          [(annLoc, opts)] -> do
            -- Record the name before parsing, so that a declaration which
            -- fails to parse is not also reported as an unapplied annotation
            tell (Set.singleton name)
            case runExcept (viewRecord annLoc opts decl) of
              Left e -> do
                lift $ issueError (exceptionLoc e) (exceptionToSDoc e)
                -- Declaration is returned unchanged when it cannot be parsed
                return [decl]
              Right r -> lift $ do
                dynFlags <- getDynFlags
                names    <- getQualifiedNames
                newDecls <- runFreshHsc $ genLargeRecord names r dynFlags
                -- Optionally show the generated code as a warning
                when (debugLargeRecords opts) $
                  issueWarning l (debugMsg newDecls)
                pure newDecls
      _otherwise ->
        pure [decl]
  where
    -- Render the generated declarations, one per line, below a fixed header,
    -- with the pretty-printer depth set to AllTheWay.
    debugMsg :: [LHsDecl GhcPs] -> SDoc
    debugMsg newDecls = pprSetDepth AllTheWay $ vcat $
          text "large-records: splicing in the following definitions:"
        : map ppr newDecls
-- | Warn about language extensions that are required but not enabled.
--
-- Each entry of @requiredExtensions@ is checked with 'isEnabled' against the
-- current 'DynFlags'. If any entry is not satisfied, a single warning
-- listing all unsatisfied entries is issued at the given location. Nothing
-- is reported when all entries are satisfied.
checkEnabledExtensions :: SrcSpan -> Hsc ()
checkEnabledExtensions l = do
    dynFlags <- getDynFlags
    let missing :: [RequiredExtension]
        missing = filter (not . isEnabled dynFlags) requiredExtensions
    unless (null missing) $
      issueWarning l $ vcat . concat $ [
          [text "Please enable these extensions for use with large-records:"]
        , map ppr missing
        ]
  where
    -- One entry per requirement; an entry listing several extensions is a
    -- set of alternatives (see isEnabled).
    requiredExtensions :: [RequiredExtension]
    requiredExtensions = [
        RequiredExtension [ConstraintKinds]
      , RequiredExtension [DataKinds]
      , RequiredExtension [ExistentialQuantification, GADTs]
      , RequiredExtension [FlexibleInstances]
      , RequiredExtension [MultiParamTypeClasses]
      , RequiredExtension [ScopedTypeVariables]
      , RequiredExtension [TypeFamilies]
      , RequiredExtension [UndecidableInstances]
      ]
-- | A requirement on the enabled language extensions.
--
-- The list holds alternatives: 'isEnabled' accepts the requirement when any
-- one of the listed extensions is enabled.
newtype RequiredExtension = RequiredExtension [Extension]
-- | Renders the alternatives on one line, separated by the word @or@.
instance Outputable RequiredExtension where
  ppr (RequiredExtension exts) = hsep . intersperse (text "or") $ map ppr exts
-- | Check whether a requirement is satisfied by the given 'DynFlags'.
--
-- The requirement holds when at least one of its alternative extensions is
-- reported as enabled by 'xopt'. A requirement with no alternatives is never
-- satisfied.
isEnabled :: DynFlags -> RequiredExtension -> Bool
isEnabled dynflags (RequiredExtension exts) = any (`xopt` dynflags) exts
-- | Report an error with the given message at the given location.
--
-- The message is wrapped in an error envelope (built with 'neverQualify')
-- and handed to 'throwOneError'. How the envelope is built depends on the
-- GHC version, hence the CPP branches.
issueError :: SrcSpan -> SDoc -> Hsc ()
issueError l errMsg = do
#if __GLASGOW_HASKELL__ == 902
    throwOneError $
      mkErr l neverQualify (mkDecorated [errMsg])
#elif __GLASGOW_HASKELL__ >= 904
    throwOneError $
      mkErrorMsgEnvelope
        l
        neverQualify
        (GhcUnknownMessage $ mkPlainError [] errMsg)
#else
    dynFlags <- getDynFlags
    throwOneError $
      mkErrMsg dynFlags l neverQualify errMsg
#endif
-- | Report a warning with the given message at the given location.
--
-- The message is wrapped in a warning envelope (built with 'neverQualify'),
-- placed in a singleton bag, and handed to the GHC routine that prints or
-- throws warnings. The exact API depends on the GHC version, hence the CPP
-- branches.
--
-- NOTE(review): whether the warning is printed or thrown is decided by the
-- GHC routine from the flags it is given, presumably @-Werror@; confirm.
issueWarning :: SrcSpan -> SDoc -> Hsc ()
issueWarning l errMsg = do
    dynFlags <- getDynFlags
#if __GLASGOW_HASKELL__ == 902
    logger <- getLogger
    liftIO $ printOrThrowWarnings logger dynFlags . bag $
      mkWarnMsg l neverQualify errMsg
#elif __GLASGOW_HASKELL__ >= 904
    logger <- getLogger
    liftIO $ printOrThrowDiagnostics logger (initDiagOpts dynFlags) . mkMessages . bag $
      mkMsgEnvelope
        (initDiagOpts dynFlags)
        l
        neverQualify
        (GhcUnknownMessage $ mkPlainDiagnostic WarningWithoutFlag [] errMsg)
#else
    liftIO $ printOrThrowWarnings dynFlags . bag $
      mkWarnMsg dynFlags l neverQualify errMsg
#endif
  where
    -- Singleton bag
    bag :: a -> Bag a
    bag = listToBag . (:[])