{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE CPP             #-}
{-# LANGUAGE ViewPatterns    #-}
{-# LANGUAGE BangPatterns    #-}
--------------------------------------------------------------------------------
-- |
-- Module      : GHC.Plugin.LensThRewrite
-- Copyright   : (c) 2020 David Johnson
-- License     : All Rights Reserved
-- Maintainer  : David Johnson
-- Stability   : Experimental
-- Portability : GHC
--
-- GHC plugin that rewrites top-level @makeLenses ''T@ splices into ordinary
-- lens definitions (a type signature plus a @lens getter setter@ binding per
-- record field), operating on the parsed module.
--
--------------------------------------------------------------------------------
module GHC.Plugin.LensThRewrite
  ( plugin
  , rewriteModule
  ) where

import           Control.Arrow
import           Control.Lens
import           Data.Function (on)
import           Data.List

import           CoreSyn
import           GhcPlugins
import           HsDecls
import           HsDumpAst
import           HsExtension
import           HsSyn
import           OccName
import           RdrName
import           TcEvidence
import           Var

-- | Lens rewrite plugin. Runs 'rewriteMakeLenses' on the parsed module and
-- declares itself pure for recompilation purposes.
plugin :: Plugin
plugin = defaultPlugin
  { parsedResultAction = \_ _ -> rewriteMakeLenses
  , pluginRecompile = purePlugin
  }

-- | Rewrite every supported @makeLenses@ splice found among the top-level
-- declarations of a module. The module itself is also the place where the
-- record definitions are looked up.
rewriteModule :: HsModule GhcPs -> HsModule GhcPs
rewriteModule module' =
  module' & decls %~ concatMap (modifyDecls module')

-- | 'rewriteModule' lifted to the 'HsParsedModule' handed to the plugin.
rewriteMakeLenses :: HsParsedModule -> Hsc HsParsedModule
rewriteMakeLenses parsed =
  pure (parsed & hsMod . located %~ rewriteModule)

-- | Lens onto the located module of a parsed module (same as 'hsMod').
parsedModule :: Lens' HsParsedModule (Located (HsModule GhcPs))
parsedModule = hsMod

-- | Lens onto the top-level declarations of a parsed module.
moduleDecls :: Lens' HsParsedModule [LHsDecl GhcPs]
moduleDecls = hsMod . located . decls

-- | Lens onto the 'hpm_module' field.
hsMod :: Lens' HsParsedModule (Located (HsModule GhcPs))
hsMod = lens hpm_module $ \parsed m -> parsed { hpm_module = m }

-- | Lens onto the payload of a located value; the location is preserved on
-- update.
located :: Lens' (Located a) a
located = lens getter setter
  where
    getter (L _ payload) = payload
    setter (L loc _) payload = L loc payload

-- | Lens onto the 'hsmodDecls' field.
decls :: Lens' (HsModule a) [LHsDecl a]
decls = lens hsmodDecls $ \m ds -> m { hsmodDecls = ds }

-- | Expand a single top-level declaration.
--
-- A @makeLenses ''T@ splice is replaced by the generated lens declarations
-- (carrying no source location). Every other declaration is returned as is.
-- If nothing can be generated for a splice (unsupported argument form, @T@
-- not declared in this module, or no eligible fields) the original splice is
-- kept rather than silently dropped.
modifyDecls :: HsModule GhcPs -> LHsDecl GhcPs -> [LHsDecl GhcPs]
modifyDecls m original@(L _ decl)
  | isMakeLensesThSplice decl
  , generated@(_ : _) <- concatMap lensesFor (getMakeLensesSplices decl)
  = emptyL <$> generated
  | otherwise = [original]
  where
    recordTable = getDecls m

    lensesFor :: TypeName -> [HsDecl GhcPs]
    lensesFor typeName =
      case lookup typeName recordTable of
        Just fields -> concatMap (genLensCall typeName) fields
        Nothing     -> []

-- | If the declaration is an untyped top-level splice of the form
-- @makeLenses arg@, return @arg@.
--
-- Only the unqualified name @makeLenses@ is recognised.
makeLensesArgument :: HsDecl GhcPs -> Maybe (HsExpr GhcPs)
makeLensesArgument (SpliceD _ (SpliceDecl _ (L _ splice) _))
  | HsUntypedSplice _ _ _ (L _ (HsApp _ (L _ fun) (L _ arg))) <- splice
  , HsVar NoExt (L _ (Unqual (occNameString -> "makeLenses"))) <- fun
  = Just arg
makeLensesArgument _ = Nothing

-- | Is this declaration a splice applying @makeLenses@ to some argument?
isMakeLensesThSplice :: HsDecl GhcPs -> Bool
isMakeLensesThSplice decl =
  case makeLensesArgument decl of
    Just _  -> True
    Nothing -> False

-- | The type name mentioned by a @makeLenses ''T@ splice, as a (zero or one
-- element) list. Any other declaration, or any other argument form, yields
-- the empty list.
getMakeLensesSplices :: HsDecl GhcPs -> [String]
getMakeLensesSplices decl =
  case makeLensesArgument decl of
    Just (HsBracket NoExt (VarBr NoExt False (Unqual (occNameString -> typ)))) ->
      [typ]
    _ -> []

-- | Unqualified variable expression.
mkVar :: String -> HsExpr GhcPs
mkVar x = HsVar NoExt (mkName x)

type FieldName = String
type TypeName = String

-- | Build the signature @name :: Lens' inner outer@.
genSigD
  :: FieldName
  -- ^ Name the signature is attached to
  -> TypeName
  -- ^ Inner type, i.e. "Person" in Lens' Person Int
  -> HsType GhcPs
  -- ^ Outer type, i.e. "Int" in Lens' Person Int
  -> HsDecl GhcPs
genSigD fieldName innerType outerType =
  SigD NoExt (TypeSig NoExt [ mkName fieldName ] hsWc)
  where
    hsWc = HsWC NoExt hsIb
    hsIb = HsIB NoExt (emptyL result)
    result = tyVarLens `appTy` tyC innerType `appTy` outerType

-- | Type application of two unlocated types.
appTy :: HsType GhcPs -> HsType GhcPs -> HsType GhcPs
appTy = HsAppTy NoExt `on` emptyL

-- | The type constructor @Lens'@ (unqualified).
tyVarLens :: HsType GhcPs
tyVarLens = tyC "Lens'"

-- | Unpromoted type built from 'mkTyVarName'.
-- NOTE(review): despite the name this produces a name in the type-constructor
-- namespace, exactly like 'tyC'.
tyVar :: String -> HsType GhcPs
tyVar s = HsTyVar NoExt NotPromoted (mkTyVarName s)

-- | Unpromoted type constructor reference.
tyC :: String -> HsType GhcPs
tyC s = HsTyVar NoExt NotPromoted (mkTyCName s)

-- | Association list from each data declaration's name to its record fields
-- (field name and field type), for all top-level type/class declarations of
-- the module. Non-data declarations contribute an entry with an empty name
-- and no fields.
getDecls :: HsModule GhcPs -> [(TypeName, [(FieldName, HsType GhcPs)])]
getDecls m = concatMap (go . (^. located)) (hsmodDecls m)
  where
    go :: HsDecl GhcPs -> [(String, [(String, HsType GhcPs)])]
    go (TyClD NoExt d) = [(getDeclTypeName &&& getFieldAndTypeName) d]
    go _               = []

-- | Unqualified, unlocated name in the variable namespace.
mkName :: String -> Located RdrName
mkName = emptyL . mkRdrUnqual . mkOccName OccName.varName

-- | Unqualified, unlocated name in the type-constructor namespace.
mkTyVarName :: String -> Located RdrName
mkTyVarName = emptyL . mkRdrUnqual . mkOccName OccName.tcName

-- | Unqualified, unlocated name in the type-constructor namespace.
mkTyCName :: String -> Located RdrName
mkTyCName = emptyL . mkRdrUnqual . mkOccName OccName.tcName

-- | Extract existing type information from a Type or class Decl: the name of
-- a data declaration, or the empty string for anything else.
getDeclTypeName :: TyClDecl GhcPs -> String
getDeclTypeName DataDecl { tcdLName = L _ name } =
  occNameString (rdrNameOcc name)
getDeclTypeName _ = mempty

-- | Extract field name information from a record.
--
-- Collects @(field name, field type)@ for every field of every record-syntax
-- (Haskell-98 style) constructor of a data declaration:
--
-- * fields declared together (@{ _x, _y :: Int }@) each get an entry;
-- * a strictness annotation on the field type is removed, since the type is
--   reused in a lens signature;
-- * a field shared by several constructors is reported once, so only one
--   lens is generated for it;
-- * constructors that are not in record syntax are skipped.
getFieldAndTypeName :: TyClDecl GhcPs -> [(String, HsType GhcPs)]
getFieldAndTypeName DataDecl { tcdDataDefn = HsDataDefn { dd_cons = cons } } =
  Data.List.nubBy ((==) `on` fst)
    [ (occNameString (rdrNameOcc fieldRdr), dropStrictness fieldTy)
    | L _ ConDeclH98 { con_args = RecCon (L _ fields) } <- cons
    , L _ ConDeclField { cd_fld_names = names, cd_fld_type = L _ fieldTy } <- fields
    , L _ FieldOcc { rdrNameFieldOcc = L _ fieldRdr } <- names
    ]
getFieldAndTypeName _ = []

-- | Remove an outermost strictness/unpackedness annotation from a type.
dropStrictness :: HsType GhcPs -> HsType GhcPs
dropStrictness (HsBangTy _ _ (L _ ty)) = dropStrictness ty
dropStrictness ty                      = ty

-- | Generate the declarations for one lens:
--
-- > foo :: Lens' T A
-- > foo = lens _foo (\r f -> r { _foo = f })
--
-- Following @makeLenses@, only fields whose name starts with an underscore
-- (followed by at least one character) get a lens; for any other field the
-- result is empty.
genLensCall :: String -> (String, HsType GhcPs) -> [HsDecl GhcPs]
genLensCall lensInnerType (fieldName, fieldType) =
  case fieldName of
    '_' : lensName@(_ : _) ->
      [ genSigD lensName lensInnerType fieldType
      , valD (funBind lensName (lensBody lensName))
      ]
    _ -> []
  where
    -- @lens <field> (\r f -> r { <field> = f })@
    lensBody lensName =
      matchGroup
        [ match (funRhs lensName) $
            grhss
              [ grhs $
                  (hsVar "lens" `hsApp` hsVar fieldName)
                    `hsApp` hsPar (hsLam setter)
              ]
        ]

    -- @\r f -> r { <field> = f }@
    setter =
      matchGroup
        [ lambdaMatch [ lambdaVarPat "r", lambdaVarPat "f" ] $
            grhss
              [ grhs $
                  recordUpd (hsVar "r")
                    [ hsRecUpdField fieldName (hsVar "f") ]
              ]
        ]

-- | Variable pattern in the shape 'lambdaMatch' expects; whether patterns are
-- wrapped in a location depends on the compiler version.
lambdaVarPat :: String -> LPat GhcPs
#if MIN_VERSION_base(4,13,0)
lambdaVarPat = varPat
#else
lambdaVarPat = emptyL . varPat
#endif

-- | Wrap a binding as a top-level value declaration.
valD :: HsBind GhcPs -> HsDecl GhcPs
valD = ValD NoExt

-- | Function binding with the given name and match group.
funBind :: String -> MatchGroup GhcPs (LHsExpr GhcPs) -> HsBind GhcPs
funBind s mg = FunBind NoExt (mkName s) mg WpHole []

-- | Match group (origin 'FromSource') from unlocated matches.
matchGroup
  :: [Match GhcPs (LHsExpr GhcPs)]
  -> MatchGroup GhcPs (LHsExpr GhcPs)
matchGroup xs = MG NoExt (emptyL (fmap emptyL xs)) FromSource

-- | Match without argument patterns.
match
  :: HsMatchContext (NameOrRdrName (IdP GhcPs)) -- see funRhs
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> Match GhcPs (LHsExpr GhcPs)
match ctx rhs = Match NoExt ctx [] rhs

-- | Match of a lambda expression with the given argument patterns.
lambdaMatch
  :: [LPat GhcPs]
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> Match GhcPs (LHsExpr GhcPs)
lambdaMatch pats rhs = Match NoExt LambdaExpr pats rhs

-- | Match context of a prefix function definition with the given name.
funRhs :: String -> HsMatchContext (NameOrRdrName (IdP GhcPs))
funRhs x = FunRhs (mkName x) Prefix NoSrcStrict

-- | Attach 'noSrcSpan' to a value.
emptyL :: e -> GenLocated SrcSpan e
emptyL = L noSrcSpan

-- | Guarded right-hand sides without local bindings.
grhss :: [GRHS GhcPs (LHsExpr GhcPs)] -> GRHSs GhcPs (LHsExpr GhcPs)
grhss xs = GRHSs NoExt (fmap emptyL xs) (emptyL (EmptyLocalBinds NoExt))

-- | Right-hand side without guards.
grhs :: HsExpr GhcPs -> GRHS GhcPs (LHsExpr GhcPs)
grhs = GRHS NoExt [] . emptyL

-- | Function application.
hsApp :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs
hsApp l r = HsApp NoExt (emptyL l) (emptyL r)

-- | Unqualified variable expression.
hsVar :: String -> HsExpr GhcPs
hsVar = HsVar NoExt . mkName

-- | Parenthesised expression.
hsPar :: HsExpr GhcPs -> HsExpr GhcPs
hsPar = HsPar NoExt . emptyL

-- | Lambda expression.
hsLam :: MatchGroup GhcPs (LHsExpr GhcPs) -> HsExpr GhcPs
hsLam = HsLam NoExt

-- | Record update expression @e { fs }@.
recordUpd :: HsExpr GhcPs -> [HsRecUpdField GhcPs] -> HsExpr GhcPs
recordUpd e fs = RecordUpd NoExt (emptyL e) (emptyL <$> fs)

-- | Single @field = expr@ entry of a record update (no punning).
hsRecUpdField :: String -> HsExpr GhcPs -> HsRecUpdField GhcPs
hsRecUpdField s e = HsRecField (emptyL (ambig s)) (emptyL e) False
  where
    ambig :: String -> AmbiguousFieldOcc GhcPs
    ambig field = Unambiguous NoExt (mkName field)

-- | Variable pattern.
varPat :: String -> Pat GhcPs
varPat = VarPat NoExt . mkName