{-# LANGUAGE RecordWildCards, ViewPatterns, NamedFieldPuns, OverloadedStrings #-}
module RecordDotPreprocessor(plugin) where
import Data.Generics.Uniplate.Data
import Data.List.Extra
import Data.Tuple.Extra
import Compat
import Bag
import qualified GHC
import qualified GhcPlugins as GHC
import qualified PrelNames as GHC
import SrcLoc
import TcEvidence
-- | GHC plugin entry point.
--
-- Only the parsed-result hook is overridden: every parsed module has its
-- syntax tree passed through 'onModule'. The plugin declares itself pure
-- via 'GHC.purePlugin'.
plugin :: GHC.Plugin
plugin = GHC.defaultPlugin
    { GHC.parsedResultAction = rewriteParsed
    , GHC.pluginRecompile = GHC.purePlugin
    }
  where
    -- Command-line options and the module summary are not consulted.
    rewriteParsed _ _ parsed =
        pure parsed { GHC.hpm_module = fmap onModule (GHC.hpm_module parsed) }
-- | Give a located value a new source span, keeping its payload.
setL :: SrcSpan -> GenLocated SrcSpan e -> GenLocated SrcSpan e
setL newSpan located = case located of
    L _ payload -> L newSpan payload
-- | The module that the generated code refers to for 'HasField',
-- 'getField' and 'setField'.
mod_records :: GHC.ModuleName
mod_records = GHC.mkModuleName recordsModule
  where
    recordsModule :: String
    recordsModule = "GHC.Records.Extra"
-- | Names used by the generated code.
--
-- The class and the getter/setter functions are qualified with
-- 'mod_records'. The method name bound in generated instances
-- ('var_hasField') and the composition operator ('var_dot') are left
-- unqualified.
var_HasField, var_hasField, var_getField, var_setField, var_dot :: GHC.RdrName
var_HasField = GHC.mkRdrQual mod_records (GHC.mkClsOcc "HasField")
var_hasField = GHC.mkRdrUnqual (GHC.mkVarOcc "hasField")
var_getField = GHC.mkRdrQual mod_records (GHC.mkVarOcc "getField")
var_setField = GHC.mkRdrQual mod_records (GHC.mkVarOcc "setField")
var_dot      = GHC.mkRdrUnqual (GHC.mkVarOcc ".")
-- | Rewrite a parsed module.
--
-- The import list is extended by 'onImports'. Each top-level declaration
-- is expanded by 'onDecl', which may turn one declaration into several.
onModule :: HsModule GhcPs -> HsModule GhcPs
onModule m = m
    { hsmodImports = onImports (hsmodImports m)
    , hsmodDecls = hsmodDecls m >>= onDecl
    }
-- | Put an import of 'mod_records' at the front of the import list.
--
-- The import is built by 'qualifiedImplicitImport', presumably a qualified
-- import so that it cannot clash with user names (defined outside this
-- view; confirm in Compat).
onImports :: [LImportDecl GhcPs] -> [LImportDecl GhcPs]
onImports imports = qualifiedImplicitImport mod_records : imports
-- | Build the class instance generated for one record field.
--
-- For a field @sel@ of record type @record@ with field type @field@, the
-- result corresponds to:
--
-- > instance (field ~ aplg) => HasField "sel" record aplg where
-- >     hasField r = (\x -> r{sel = x}, (sel :: record -> field) r)
--
-- Here @HasField@ is the class qualified with 'mod_records'.
instanceTemplate :: FieldOcc GhcPs -> HsType GhcPs -> HsType GhcPs -> InstDecl GhcPs
instanceTemplate selector record field = ClsInstD noE $ ClsInstDecl noE (HsIB noE typ) (unitBag has) [] [] [] Nothing
    where
        -- The class applied to its three arguments, the last one supplied
        -- by the caller: HasField "sel" record a
        typ' a = mkHsAppTys
            (noL (HsTyVar noE GHC.NotPromoted (noL var_HasField)))
            [noL (HsTyLit noE (HsStrTy GHC.NoSourceText (GHC.occNameFS $ GHC.occName $ unLoc $ rdrNameFieldOcc selector)))
            ,noL record
            ,noL a
            ]
        -- The full instance type. The field type is not written into the
        -- head; it is tied to a fresh variable by an equality constraint
        -- built by makeEqQualTy.
        typ = noL $ makeEqQualTy field (unLoc . typ')
        -- The single method binding: hasField r = (setter, getter)
        has :: LHsBindLR GhcPs GhcPs
        has = noL $ FunBind noE (noL var_hasField) (mg1 eqn) WpHole []
            where
                -- The equation "hasField r = (set, get)".
                eqn = Match
                    { m_ext = noE
                    , m_ctxt = FunRhs (noL var_hasField) GHC.Prefix NoSrcStrict
                    , m_pats = compat_m_pats [VarPat noE $ noL vR]
                    , m_grhss = GRHSs noE [noL $ GRHS noE [] $ noL $ ExplicitTuple noE [noL $ Present noE set, noL $ Present noE get] GHC.Boxed] (noL $ EmptyLocalBinds noE)
                    }
                -- Setter: a lambda "\x -> update".
                set = noL $ HsLam noE $ mg1 Match
                    { m_ext = noE
                    , m_ctxt = LambdaExpr
                    , m_pats = compat_m_pats [VarPat noE $ noL vX]
                    , m_grhss = GRHSs noE [noL $ GRHS noE [] $ noL update] (noL $ EmptyLocalBinds noE)
                    }
                -- The record update "r{sel = x}" (not punned: the final False).
                update = RecordUpd noE (noL $ GHC.HsVar noE $ noL vR)
                    [noL $ HsRecField (noL (Unambiguous noE (rdrNameFieldOcc selector))) (noL $ GHC.HsVar noE $ noL vX) False]
                -- Getter: the field selector, annotated with the type
                -- "record -> field", applied to r.
                get = mkApp
                    (mkParen $ mkTypeAnn (noL $ GHC.HsVar noE $ rdrNameFieldOcc selector) (noL $ HsFunTy noE (noL record) (noL field)))
                    (noL $ GHC.HsVar noE $ noL vR)
                -- Wrap a single match in a match group marked GHC.Generated.
                mg1 :: Match GhcPs (LHsExpr GhcPs) -> MatchGroup GhcPs (LHsExpr GhcPs)
                mg1 x = MG noE (noL [noL x]) GHC.Generated
                -- Local variable names used in the generated code.
                vR = GHC.mkRdrUnqual $ GHC.mkVarOcc "r"
                vX = GHC.mkRdrUnqual $ GHC.mkVarOcc "x"
-- | Process one top-level declaration.
--
-- Every expression inside the declaration is rewritten with 'onExp'. A
-- type/class declaration is additionally followed by one generated
-- 'HasField' instance per distinct field name reported by 'getFields'.
-- When a field name repeats (e.g. a field shared by several
-- constructors), only the first occurrence produces an instance.
--
-- Type/class declarations are traversed too: a class declaration can
-- contain expressions (default method bodies) that may use record-dot
-- syntax. Previously they were passed through untouched.
onDecl :: LHsDecl GhcPs -> [LHsDecl GhcPs]
onDecl o@(L _ (GHC.TyClD _ x)) = descendBi onExp o :
    [ noL $ InstD noE $ instanceTemplate field (unLoc record) (unbang typ)
    | (record, _, field, typ) <- nubOrdOn fieldName $ getFields x
    ]
  where
    -- Deduplication key: the field's occurrence name.
    fieldName (_, _, fld, _) = GHC.occNameFS $ GHC.rdrNameOcc $ unLoc $ rdrNameFieldOcc fld
onDecl x = [descendBi onExp x]
-- | Remove an outer 'HsBangTy' wrapper (a strictness/unpackedness
-- annotation on a field type). Any other type is returned unchanged.
unbang :: HsType GhcPs -> HsType GhcPs
unbang ty = case ty of
    HsBangTy _ _ inner -> unLoc inner
    _ -> ty
-- | List every record field declared by a data/newtype declaration.
--
-- Each entry is:
--
-- * the record type, i.e. the type constructor applied to its explicit
--   type variables;
-- * the name of the constructor declaring the field;
-- * the field's selector; and
-- * the field's declared type.
--
-- Both Haskell-98 and GADT-style record constructors are handled;
-- non-record constructors and non-data declarations contribute nothing.
getFields :: TyClDecl GhcPs -> [(LHsType GhcPs, IdP GhcPs, FieldOcc GhcPs, HsType GhcPs)]
getFields DataDecl{tcdLName, tcdTyVars, tcdDataDefn = HsDataDefn{dd_cons}} =
    [ entry | con <- dd_cons, entry <- fromCon con ]
  where
    -- Fields of one constructor (one entry per constructor name for GADTs).
    fromCon (L _ ConDeclH98{con_args = RecCon (L _ flds), con_name = L _ name}) =
        concatMap (fromField name) flds
    fromCon (L _ ConDeclGADT{con_args = RecCon (L _ flds), con_names}) =
        [ entry | L _ name <- con_names, fld <- flds, entry <- fromField name fld ]
    fromCon _ = []

    -- A single field declaration may name several fields sharing one type.
    fromField name (L _ ConDeclField{cd_fld_type = L _ ty, cd_fld_names}) =
        [ (headTy, name, fld, ty) | L _ fld <- cd_fld_names ]
    fromField _ _ = error "unknown field declaration in getFields"

    -- T a b ... built by applying the type constructor to each binder.
    headTy = foldl
        (\acc bndr -> noL $ HsAppTy noE acc $ hsLTyVarBndrToType bndr)
        (noL $ HsTyVar noE GHC.NotPromoted tcdLName)
        (hsq_explicit tcdTyVars)
getFields _ = []
-- | Rewrite record-dot syntax inside an expression, top-down.
--
-- Three forms are recognised. Anything else is traversed with 'descend'.
--
-- * Field access @e.sel@: a @.@ operator touching both operands (no
--   whitespace). The right operand, after peeling its application
--   arguments and record updates, must be an unqualified variable. The
--   result is a 'var_getField' call at the type-level string of @sel@,
--   applied to @e@. Surrounding operator/application structure is
--   re-attached, so @f a.b c@ becomes @f (getField "b" a) c@.
--
-- * Field-access section @(.a.b.c)@: the section's span runs from the @.@
--   to the end of the last selector, and all dots are adjacent. The result
--   is the composition getField "c" . getField "b" . getField "a".
--
-- * Field update @e{f1 = v1, f2 = v2}@ with the first field starting one
--   column after @e@ ends. The result is nested 'var_setField' calls,
--   applied field by field from left to right; a punned field @e{f}@ uses
--   the variable @f@ as its value.
--
-- In all three cases the rewritten expression keeps the span of the
-- original. The update case previously dropped it, unlike the other two.
onExp :: LHsExpr GhcPs -> LHsExpr GhcPs
onExp (L o (OpApp _ lhs mid@(isDot -> True) rhs))
    | adjacent lhs mid, adjacent mid rhs
    -- Split off whatever the left operand is embedded in (an operator
    -- application, then a function application) and whatever follows the
    -- selector on the right (arguments, record updates).
    , (lhsOp, lhs) <- getOpRHS $ onExp lhs
    , (lhsApp, lhs) <- getAppRHS lhs
    , (rhsApp, rhs) <- getAppLHS rhs
    , (rhsRec, rhs) <- getRec rhs
    , Just sel <- getSelector rhs
    = onExp $ setL o $ lhsOp $ rhsApp $ lhsApp $ rhsRec $ mkParen $ mkVar var_getField `mkAppType` sel `mkApp` lhs
onExp (L o (SectionR _ mid@(isDot -> True) rhs))
    | adjacent mid rhs
    , srcSpanStart o == srcSpanStart (getLoc mid)
    , srcSpanEnd o == srcSpanEnd (getLoc rhs)
    , Just sels <- getSelectors rhs
    = setL o $ foldl1 (\x y -> noL $ OpApp noE x (mkVar var_dot) y) $ map (mkVar var_getField `mkAppType`) $ reverse sels
onExp (L o RecordUpd{rupd_expr, rupd_flds = fld:flds})
    | adjacentBy 1 rupd_expr fld
    = onExp $ setL o $ f rupd_expr $ fld:flds
    where
        -- Fold the fields left to right into nested setField calls.
        f expr [] = expr
        f expr (L _ (HsRecField (fmap rdrNameAmbiguousFieldOcc -> lbl) arg pun) : flds)
            | let sel = mkSelector lbl
            , let arg2 = if pun then noL $ HsVar noE lbl else arg
            , let expr2 = mkParen $ mkVar var_setField `mkAppType` sel `mkApp` expr `mkApp` arg2
            = f expr2 flds
onExp x = descend onExp x
-- | Turn a name into a type-level string literal of its (unqualified)
-- occurrence name, keeping the name's span: @foo@ becomes @"foo"@.
mkSelector :: Located GHC.RdrName -> LHsType GhcPs
mkSelector (L o x) = L o (HsTyLit noE (HsStrTy GHC.NoSourceText label))
  where
    label = GHC.occNameFS (GHC.rdrNameOcc x)
-- | Interpret an expression as a field selector.
--
-- Only an unqualified variable qualifies; its name becomes a type-level
-- string via 'mkSelector'. Anything else yields 'Nothing'.
getSelector :: LHsExpr GhcPs -> Maybe (LHsType GhcPs)
getSelector (L _ (HsVar _ name@(L _ sym)))
    | GHC.isQual sym = Nothing
    | otherwise = Just (mkSelector name)
getSelector _ = Nothing
-- | Interpret @a.b.c@ (dots adjacent to both neighbours) as a left-to-right
-- list of selectors. A single selector yields a one-element list; anything
-- that is not entirely selectors yields 'Nothing'.
getSelectors :: LHsExpr GhcPs -> Maybe [LHsType GhcPs]
getSelectors (L _ (OpApp _ lhs mid@(isDot -> True) rhs))
    | adjacent lhs mid && adjacent mid rhs
    , Just post <- getSelector rhs
    , Just pre <- getSelectors lhs
    = Just (pre ++ [post])
getSelectors e = fmap pure (getSelector e)
-- | For an application @f x@, return @x@ together with a function that
-- rebuilds @f _@ around a replacement argument (keeping the application's
-- span). Any other expression is returned with 'id'.
getAppRHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getAppRHS expr = case expr of
    L l (HsApp ext fun arg) -> (\arg' -> L l (HsApp ext fun arg'), arg)
    _ -> (id, expr)
-- | For a (possibly nested) application @f x y ...@, return the head @f@
-- together with a function that reapplies all the original arguments, with
-- their spans, to a replacement head. A non-application is returned with
-- 'id'.
getAppLHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getAppLHS (L l (HsApp ext fun arg)) = case getAppLHS fun of
    (rebuild, innermost) -> (\new -> L l (HsApp ext (rebuild new) arg), innermost)
getAppLHS other = (id, other)
-- | For an operator application @a op b@, return @b@ together with a
-- function that rebuilds @a op _@ around a replacement (keeping the span).
-- Any other expression is returned with 'id'.
getOpRHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getOpRHS expr = case expr of
    L l (OpApp ext left op right) -> (\right' -> L l (OpApp ext left op right'), right)
    _ -> (id, expr)
-- | Peel record updates (@e{..}{..}@) off an expression.
--
-- Returns the innermost updated expression together with a function that
-- reapplies the updates to a replacement. At each level, the rebuilt
-- subject is given the span the original subject had.
getRec :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getRec (L l upd@RecordUpd{rupd_expr = inner}) = case getRec inner of
    (rebuild, core) ->
        (\new -> L l upd{rupd_expr = setL (getLoc inner) (rebuild new)}, core)
getRec other = (id, other)
-- | Is this expression the unqualified variable @.@ ('var_dot')?
isDot :: LHsExpr GhcPs -> Bool
isDot expr = case unLoc expr of
    HsVar _ (L _ name) -> name == var_dot
    _ -> False
-- | A variable reference with no source location.
mkVar :: GHC.RdrName -> LHsExpr GhcPs
mkVar name = noL (HsVar noE (noL name))
-- | Wrap an expression in parentheses, with no source location.
mkParen :: LHsExpr GhcPs -> LHsExpr GhcPs
mkParen inner = noL (HsPar noE inner)
-- | Apply a function expression to an argument, with no source location.
mkApp :: LHsExpr GhcPs -> LHsExpr GhcPs -> LHsExpr GhcPs
mkApp fun = noL . HsApp noE fun
-- | Does the second item start exactly where the first one ends, with no
-- whitespace in between?
adjacent :: Located a -> Located b -> Bool
adjacent before after = adjacentBy 0 before after
-- | Does the second item start exactly @gap@ columns after the end position
-- of the first, in the same file and on the same line?
--
-- Always 'False' if either span lacks a real source location.
adjacentBy :: Int -> Located a -> Located b -> Bool
adjacentBy gap before after
    | RealSrcLoc end <- srcSpanEnd (getLoc before)
    , RealSrcLoc start <- srcSpanStart (getLoc after)
    = srcLocFile end == srcLocFile start
        && srcLocLine end == srcLocLine start
        && srcLocCol end + gap == srcLocCol start
    | otherwise = False
-- | Build the qualified type @((rArg) ~ aplg) => fAbs aplg@.
--
-- @aplg@ is a type variable made here. @fAbs@ receives that variable, so
-- the given type @rArg@ appears only in the equality constraint.
-- Presumably this lets an instance head match any type in that position
-- and leaves the equality to be solved afterwards (review before relying
-- on this).
makeEqQualTy :: HsType GhcPs -> (HsType GhcPs -> HsType GhcPs) -> HsType GhcPs
makeEqQualTy rArg fAbs = HsQualTy noE (noL context) (noL (fAbs freshVar))
  where
    freshName = GHC.nameRdrName (GHC.mkUnboundName (GHC.mkTyVarOcc "aplg"))

    freshVar :: HsType GhcPs
    freshVar = HsTyVar noE GHC.NotPromoted (noL freshName)

    -- The equality operator, as an original name from GHC.Types.
    tildeName = GHC.mkOrig GHC.gHC_TYPES (GHC.mkClsOcc "~")

    equality :: HsType GhcPs
    equality = HsOpTy noE (noL (HsParTy noE (noL rArg))) (noLoc tildeName) (noLoc freshVar)

    context :: HsContext GhcPs
    context = [noL (HsParTy noE (noL equality))]