{-# LANGUAGE RecordWildCards, ViewPatterns, NamedFieldPuns #-}
module RecordDotPreprocessor(plugin) where
import Data.Generics.Uniplate.Data
import Data.List.Extra
import Data.Tuple.Extra
import Compat
import Bag
import qualified GHC
import qualified GhcPlugins as GHC
import HsSyn
import SrcLoc
import TcEvidence
-- | Compiler plugin entry point. Through the 'GHC.parsedResultAction' hook,
-- every parsed module is rewritten with 'onModule'. The plugin is
-- registered as 'GHC.purePlugin' for GHC's recompilation check.
plugin :: GHC.Plugin
plugin = GHC.defaultPlugin
    { GHC.parsedResultAction = rewriteParsed
    , GHC.pluginRecompile = GHC.purePlugin
    }
    where
        -- Command-line options and the module summary are ignored.
        rewriteParsed _cliOptions _modSummary parsed =
            return parsed{GHC.hpm_module = fmap onModule (GHC.hpm_module parsed)}
-- | Give a located value a new source span, keeping its payload unchanged.
setL :: SrcSpan -> GenLocated SrcSpan e -> GenLocated SrcSpan e
setL newSpan located = case located of
    L _ payload -> L newSpan payload
-- | The module providing HasField, getField and setField. 'onImports' adds
-- an import of it to every processed module, and the qualified names below
-- refer to it.
mod_records :: GHC.ModuleName
mod_records = GHC.mkModuleName recordsModule
    where recordsModule = "GHC.Records.Extra"
-- | Names referenced by the generated code.
--
-- The class and the two accessor functions are qualified with
-- 'mod_records'. 'var_hasField' is the unqualified method name bound in the
-- generated instances. 'var_dot' is the unqualified '.' operator that
-- 'isDot' looks for.
var_HasField, var_hasField, var_getField, var_setField, var_dot :: GHC.RdrName
var_HasField = GHC.mkRdrQual mod_records (GHC.mkClsOcc "HasField")
var_getField = GHC.mkRdrQual mod_records (GHC.mkVarOcc "getField")
var_setField = GHC.mkRdrQual mod_records (GHC.mkVarOcc "setField")
var_hasField = GHC.mkRdrUnqual (GHC.mkVarOcc "hasField")
var_dot = GHC.mkRdrUnqual (GHC.mkVarOcc ".")
-- | Rewrite a whole parsed module. The records import is added, and every
-- top-level declaration is passed through 'onDecl', which may expand one
-- declaration into several.
onModule :: HsModule GhcPs -> HsModule GhcPs
onModule m@HsModule{hsmodImports = imports, hsmodDecls = decls} =
    m{hsmodImports = onImports imports, hsmodDecls = concatMap onDecl decls}
-- | Prepend an import of 'mod_records' to a module's import list. This lets
-- the qualified names built from it ('var_HasField', 'var_getField',
-- 'var_setField') resolve in the rewritten module.
onImports :: [LImportDecl GhcPs] -> [LImportDecl GhcPs]
onImports imports = recordsImport : imports
    where
        -- The positional ImportDecl fields after the module name are assumed
        -- to follow the GHC 8.6 order (TODO confirm for other versions): no
        -- package qualifier, not a SOURCE import, not safe, qualified,
        -- implicit, no alias, no import list.
        recordsImport = noL $ GHC.ImportDecl GHC.NoExt GHC.NoSourceText (noL mod_records)
            Nothing False False True True Nothing Nothing
-- | Build the HasField instance for one record field.
--
-- Given the field @selector@, the record type @record@ and the field type
-- @field@, the generated declaration is
--
-- > instance GHC.Records.Extra.HasField "selector" record field where
-- >     hasField r = (\x -> r{selector = x}, (selector :: record -> field) r)
--
-- 'hasField' therefore returns a (setter, current value) pair.
instanceTemplate :: FieldOcc GhcPs -> HsType GhcPs -> HsType GhcPs -> InstDecl GhcPs
-- The three empty lists and Nothing fill the remaining ClsInstDecl fields.
-- Assuming the GHC 8.6 field order, these are signatures, type/data family
-- instances and the overlap mode (TODO confirm for other GHC versions).
instanceTemplate selector record field = ClsInstD noE $ ClsInstDecl noE (HsIB noE typ) (unitBag has) [] [] [] Nothing
where
-- Instance head: HasField applied to the field name as a type-level string
-- literal, then the record type, then the field type.
typ = mkHsAppTys
(noL (HsTyVar noE GHC.NotPromoted (noL var_HasField)))
[noL (HsTyLit noE (HsStrTy GHC.NoSourceText (GHC.occNameFS $ GHC.occName $ unLoc $ rdrNameFieldOcc selector)))
,noL record
,noL field
]
-- The single method binding: hasField r = (set, get).
has :: LHsBindLR GhcPs GhcPs
has = noL $ FunBind noE (noL var_hasField) (mg1 eqn) WpHole []
where
eqn = Match
{ m_ext = noE
, m_ctxt = FunRhs (noL var_hasField) GHC.Prefix NoSrcStrict
, m_pats = compat_m_pats [VarPat noE $ noL vR]
, m_grhss = GRHSs noE [noL $ GRHS noE [] $ noL $ ExplicitTuple noE [noL $ Present noE set, noL $ Present noE get] GHC.Boxed] (noL $ EmptyLocalBinds noE)
}
-- set = \x -> r{selector = x}
set = noL $ HsLam noE $ mg1 Match
{ m_ext = noE
, m_ctxt = LambdaExpr
, m_pats = compat_m_pats [VarPat noE $ noL vX]
, m_grhss = GRHSs noE [noL $ GRHS noE [] $ noL update] (noL $ EmptyLocalBinds noE)
}
-- A native record update of the single field. onDecl does not pass the
-- generated instances through onExp, so this update is not turned into a
-- setField call.
update = RecordUpd noE (noL $ GHC.HsVar noE $ noL vR)
[noL $ HsRecField (noL (Unambiguous noE (rdrNameFieldOcc selector))) (noL $ GHC.HsVar noE $ noL vX) False]
-- get = (selector :: record -> field) r. mkTypeAnn comes from Compat and
-- presumably builds the expression type annotation (TODO confirm there).
get = mkApp
(mkParen $ mkTypeAnn (noL $ GHC.HsVar noE $ rdrNameFieldOcc selector) (noL $ HsFunTy noE (noL record) (noL field)))
(noL $ GHC.HsVar noE $ noL vR)
-- Wrap a single match as a MatchGroup tagged GHC.Generated (not FromSource).
mg1 :: Match GhcPs (LHsExpr GhcPs) -> MatchGroup GhcPs (LHsExpr GhcPs)
mg1 x = MG noE (noL [noL x]) GHC.Generated
-- Names used inside the generated code: r is the record, x the new value.
vR = GHC.mkRdrUnqual $ GHC.mkVarOcc "r"
vX = GHC.mkRdrUnqual $ GHC.mkVarOcc "x"
-- | Rewrite one top-level declaration, possibly producing several.
--
-- Every declaration has the expressions inside it rewritten by 'onExp'.
-- This includes type/class declarations ('GHC.TyClD'), whose class default
-- method bodies would otherwise keep raw dot syntax and native record
-- updates.
--
-- A type/class declaration is also followed by one HasField instance per
-- distinct field name reported by 'getFields'. The generated instances are
-- deliberately not passed through 'onExp': their bodies contain a native
-- record update that must not be turned into a setField call.
onDecl :: LHsDecl GhcPs -> [LHsDecl GhcPs]
onDecl o@(L _ (GHC.TyClD _ x)) = descendBi onExp o :
    [ noL $ InstD noE $ instanceTemplate field (unLoc record) (unbang typ)
    | (record, _, field, typ) <- nubOrdOn fieldName $ getFields x
    ]
    where
        -- A field shared by several constructors must yield only one instance.
        fieldName (_, _, fld, _) = GHC.occNameFS $ GHC.rdrNameOcc $ unLoc $ rdrNameFieldOcc fld
onDecl x = [descendBi onExp x]
-- | Strip one outermost HsBangTy (a strictness/unpack annotation on a field
-- type). Any other type is returned unchanged.
unbang :: HsType GhcPs -> HsType GhcPs
unbang ty = case ty of
    HsBangTy _ _ inner -> unLoc inner
    _ -> ty
-- | Enumerate the record fields declared by a data/newtype declaration.
--
-- Each entry is (recordType, constructorName, fieldOcc, fieldType).
-- recordType is the declared type constructor applied to its explicit type
-- variables and wrapped in parentheses, e.g. (T a b). Only Haskell98-style
-- constructors written with record syntax contribute entries; any other
-- declaration or constructor form yields none. A field shared by several
-- constructors is reported once per constructor.
getFields :: TyClDecl GhcPs -> [(LHsType GhcPs, IdP GhcPs, FieldOcc GhcPs, HsType GhcPs)]
getFields DataDecl{tcdLName = tyName, tcdTyVars = tyVars, tcdDataDefn = HsDataDefn{dd_cons = cons}} =
    [ entry | L _ con <- cons, entry <- fromCon con ]
    where
        -- The declared type saturated with its binders.
        recordTy = noL $ HsParTy noE $
            foldl (\acc bndr -> noL $ HsAppTy noE acc $ hsLTyVarBndrToType bndr)
                  (noL $ HsTyVar noE GHC.NotPromoted tyName)
                  (hsq_explicit tyVars)

        fromCon ConDeclH98{con_args = RecCon (L _ decls), con_name = L _ conName} =
            concatMap (fromDecl conName) decls
        fromCon _ = []

        fromDecl conName (L _ ConDeclField{cd_fld_type = L _ ty, cd_fld_names = names}) =
            [ (recordTy, conName, occ, ty) | L _ occ <- names ]
        fromDecl _ _ = error "unknown field declaration in getFields"
getFields _ = []
-- | Rewrite record-dot syntax and record updates inside an expression,
-- recursing into every sub-expression.
--
-- * e.lbl, a '.' operator with no whitespace on either side whose right
--   operand is an unqualified variable, becomes getField applied to the
--   label type and e.
-- * A right section such as (.a.b), made only of adjacent dots, becomes the
--   composition of the corresponding getField calls, innermost first.
-- * Every record update e{lbl = v, ...} becomes nested setField calls, one
--   per field, applied left to right.
onExp :: LHsExpr GhcPs -> LHsExpr GhcPs
onExp (L o (OpApp _ lhs mid@(isDot -> True) rhs))
    | adjacent lhs mid, adjacent mid rhs
    -- First split off the surrounding operator/application context. The
    -- record is the right-most operand/argument of the left side (r in
    -- f r.x). The field is the head of the right side, which may be followed
    -- by arguments or record updates (r.x y, r.x{...}). The context is then
    -- rebuilt around the getField call.
    , (lhsOp, lhsOperand) <- getOpRHS $ onExp lhs
    , (lhsApp, record) <- getAppRHS lhsOperand
    , (rhsApp, rhsHead) <- getAppLHS rhs
    , (rhsRec, field) <- getRec rhsHead
    , Just sel <- getSelector field
    = onExp $ lhsOp $ rhsApp $ lhsApp $ rhsRec $ mkParen $ setL o $ mkVar var_getField `mkAppType` sel `mkApp` record
onExp (L o (SectionR _ mid@(isDot -> True) rhs))
    | adjacent mid rhs
    -- The section's span must run exactly from the dot to the end of the
    -- selector chain.
    , srcSpanStart o == srcSpanStart (getLoc mid)
    , srcSpanEnd o == srcSpanEnd (getLoc rhs)
    , Just sels <- getSelectors rhs
    = setL o $ mkParen $ foldl1 (\x y -> noL $ OpApp noE x (mkVar var_dot) y) $ map (mkVar var_getField `mkAppType`) $ reverse sels
onExp (L _ RecordUpd{rupd_expr,rupd_flds=[]}) = onExp rupd_expr
onExp (L o upd@RecordUpd{rupd_expr,rupd_flds=L _ (HsRecField (fmap rdrNameAmbiguousFieldOcc -> lbl) arg pun):flds})
    | let sel = mkSelector lbl
    -- A pun r{M.lbl} stands for r{M.lbl = lbl}. The argument is the
    -- unqualified variable, as in GHC's own pun expansion. Using the label
    -- itself would refer to the qualified selector function M.lbl.
    , let arg2 = if pun then noL $ HsVar noE $ fmap (GHC.mkRdrUnqual . GHC.rdrNameOcc) lbl else arg
    , let expr = mkParen $ mkVar var_setField `mkAppType` sel `mkApp` mkParen rupd_expr `mkApp` arg2
    = onExp $ L o upd{rupd_expr=expr,rupd_flds=flds}
onExp x = descend onExp x
-- | Turn a field name into the type-level string literal used as the label
-- argument of HasField/getField/setField. Only the occurrence name is kept;
-- any module qualifier is dropped. The source span is preserved.
mkSelector :: Located GHC.RdrName -> LHsType GhcPs
mkSelector (L loc name) = L loc (HsTyLit noE (HsStrTy GHC.NoSourceText label))
    where label = GHC.occNameFS (GHC.rdrNameOcc name)
-- | If the expression is an unqualified variable, return its name as a
-- label type (see 'mkSelector'); otherwise Nothing.
getSelector :: LHsExpr GhcPs -> Maybe (LHsType GhcPs)
getSelector expr = case unLoc expr of
    HsVar _ name | not (GHC.isQual (unLoc name)) -> Just (mkSelector name)
    _ -> Nothing
-- | Read a chain of selectors a.b.c, written with adjacent dots, as a list of
-- label types in source order. A single selector gives a one-element list;
-- anything else gives Nothing.
getSelectors :: LHsExpr GhcPs -> Maybe [LHsType GhcPs]
getSelectors expr = case expr of
    L _ (OpApp _ lhs mid rhs)
        | isDot mid, adjacent lhs mid, adjacent mid rhs
        , Just post <- getSelector rhs
        , Just pre <- getSelectors lhs
        -> Just (pre ++ [post])
    _ -> fmap pure (getSelector expr)
-- | For an application f a, return its argument a and a function that puts
-- a replacement argument back under f. Any other expression is returned
-- with 'id'.
getAppRHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getAppRHS expr = case expr of
    L loc (HsApp ext fun arg) -> (\newArg -> L loc (HsApp ext fun newArg), arg)
    _ -> (id, expr)
-- | Find the head of an application spine (f in f a b). Return it together
-- with a function that rebuilds the whole spine around a replacement head.
-- A non-application is its own head.
getAppLHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getAppLHS expr = case expr of
    L loc (HsApp ext fun arg) -> case getAppLHS fun of
        (rebuild, headExpr) -> (\newHead -> L loc (HsApp ext (rebuild newHead) arg), headExpr)
    _ -> (id, expr)
-- | For an operator application l op r, return the right operand r and a
-- function that puts a replacement operand back in its place. Any other
-- expression is returned with 'id'.
getOpRHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getOpRHS expr = case expr of
    L loc (OpApp ext lhs op rhs) -> (\newRhs -> L loc (OpApp ext lhs op newRhs), rhs)
    _ -> (id, expr)
-- | Look through any number of nested record updates e{..}{..} to the
-- innermost updated expression e. Return e with a function that re-applies
-- all the updates around a replacement for it.
getRec :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getRec expr = case expr of
    L loc r@RecordUpd{rupd_expr = inner} -> case getRec inner of
        (rebuild, core) -> (\newCore -> L loc r{rupd_expr = rebuild newCore}, core)
    _ -> (id, expr)
-- | Is this expression exactly the unqualified '.' operator ('var_dot')?
isDot :: LHsExpr GhcPs -> Bool
isDot expr = case unLoc expr of
    HsVar _ (L _ op) -> op == var_dot
    _ -> False
-- | A variable reference with no source location.
mkVar :: GHC.RdrName -> LHsExpr GhcPs
mkVar name = noL (HsVar noE (noL name))
-- | Wrap an expression in parentheses (HsPar) with no source location.
mkParen :: LHsExpr GhcPs -> LHsExpr GhcPs
mkParen inner = noL (HsPar noE inner)
-- | Apply a function expression to an argument, with no source location.
mkApp :: LHsExpr GhcPs -> LHsExpr GhcPs -> LHsExpr GhcPs
mkApp fun arg = noL (HsApp noE fun arg)
-- | True when the first item has a real source span that ends exactly where
-- the second begins, i.e. no whitespace separates them. Nodes without a
-- real span, such as those built with noL, are never adjacent.
adjacent :: Located a -> Located b -> Bool
adjacent (L before _) (L after _)
    | isGoodSrcSpan before = srcSpanEnd before == srcSpanStart after
    | otherwise = False