{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
--------------------------------------------------------------------------------
-- |
-- Module      : GHC.Plugin.LensThRewrite
-- Copyright   : (c) 2020 David Johnson
-- License     : All Rights Reserved
-- Maintainer  : David Johnson
-- Stability   : Experimental
-- Portability : GHC
--
-- GHC plugin that rewrites top-level @makeLenses ''T@ Template Haskell
-- splices into ordinary lens definitions right after parsing.
--
-- For every field @_foo :: X@ of @T@ it emits
--
-- > foo :: Lens' T (X)
-- > foo = lens _foo (\r f -> r { _foo = f })
--
-- The generated code refers to @lens@ and @Lens'@ unqualified, so the
-- module being compiled must have them in scope (e.g. by importing
-- "Control.Lens" unqualified).
--
-- Only monomorphic, single-constructor, Haskell98-style records declared in
-- the same module are rewritten. Any other @makeLenses@ splice is left
-- untouched, so Template Haskell still handles it.
--
--------------------------------------------------------------------------------
module GHC.Plugin.LensThRewrite ( plugin ) where

import Control.Arrow
import Control.Lens
import qualified Data.Char as Char
import Data.Function (on)
import Data.List

import CoreSyn
import GhcPlugins
import HsDecls
import HsDumpAst
import HsExtension
import HsSyn
import OccName
import RdrName
import TcEvidence
import Var

import System.IO.Unsafe

-- | Lens rewrite plugin. It is pure: it only transforms the parsed AST.
plugin :: Plugin
plugin = defaultPlugin
  { parsedResultAction = \_ _ -> rewriteMakeLenses
  , pluginRecompile    = purePlugin
  }

-- | Replace each rewritable @makeLenses@ splice in the module with the
-- generated signatures and bindings. Every other declaration is kept as-is.
rewriteMakeLenses :: HsParsedModule -> Hsc HsParsedModule
rewriteMakeLenses parsed =
  pure (parsed & moduleDecls %~ concatMap (modifyDecls (parsed ^. hsMod . located)))

-- | Top-level declarations of a parsed module.
moduleDecls :: Lens' HsParsedModule [LHsDecl GhcPs]
moduleDecls = hsMod . located . decls

-- | The located 'HsModule' inside an 'HsParsedModule'.
hsMod :: Lens' HsParsedModule (Located (HsModule GhcPs))
hsMod = lens hpm_module $ \s new -> s { hpm_module = new }

-- | Focus on the payload of a 'Located' value; the 'SrcSpan' is preserved.
located :: Lens' (Located a) a
located = lens getter setter
  where
    getter (L _ r) = r
    setter (L x _) y = L x y

-- | The declaration list of an 'HsModule'.
decls :: Lens' (HsModule a) [LHsDecl a]
decls = lens hsmodDecls $ \s new -> s { hsmodDecls = new }

-- | Rewrite one top-level declaration.
--
-- A @makeLenses ''T@ splice whose target this plugin can handle is replaced
-- by the generated declarations. Those declarations reuse the splice's
-- source span, so type errors point back at the splice. Every other
-- declaration, including splices we cannot handle, is returned unchanged.
-- This leaves those splices to Template Haskell instead of silently deleting
-- them.
modifyDecls :: HsModule GhcPs -> LHsDecl GhcPs -> [LHsDecl GhcPs]
modifyDecls m ldecl@(L loc decl) =
  case rewriteSplice (getDecls m) decl of
    Nothing        -> [ldecl]
    Just generated -> L loc <$> generated

-- | Generate lens declarations for a @makeLenses ''T@ splice.
--
-- Returns 'Nothing' in two cases: the declaration is not such a splice, or
-- @T@ is not in the supplied table of rewritable record types.
rewriteSplice
  :: [(TypeName, [(FieldName, HsType GhcPs)])]
  -> HsDecl GhcPs
  -> Maybe [HsDecl GhcPs]
rewriteSplice types decl = do
  typeName <- makeLensesTarget decl
  fields   <- lookup typeName types
  pure (concatMap (genLensCall typeName) fields)

-- | If the declaration is the splice @makeLenses ''T@ with an unqualified
-- type name @T@, return @T@.
makeLensesTarget :: HsDecl GhcPs -> Maybe TypeName
makeLensesTarget (SpliceD _ (SpliceDecl _ (L _ splice) _))
  | HsUntypedSplice _ _ _ (L _ expr) <- splice
  , HsApp _ (L _ fun) (L _ arg) <- expr
  , HsVar _ (L _ (Unqual (occNameString -> "makeLenses"))) <- fun
    -- VarBr with 'False' is the type-namespace quote ''T.
  , HsBracket _ (VarBr _ False (Unqual (occNameString -> typ))) <- arg
  = Just typ
makeLensesTarget _ = Nothing

-- | test
-- main :: IO ()
-- main = do
--   Right (_, L _ s) <- parseModule "Main.hs"
--   putStrLn $ showSDocUnsafe (showAstData BlankSrcSpan s)
--   let n = s & decls %~ concatMap (modifyDecls s)
--   putStrLn $ showSDocUnsafe (ppr n)

type FieldName = String
type TypeName = String

-- | Build the signature @lensName :: Lens' recordType (fieldType)@.
genSigD
  :: String         -- ^ Name of the generated lens, e.g. "name".
  -> TypeName       -- ^ Record type, e.g. "Person" in Lens' Person Int.
  -> HsType GhcPs   -- ^ Field type, e.g. "Int" in Lens' Person Int.
  -> HsDecl GhcPs
genSigD lensName recordType fieldType =
  SigD NoExt (TypeSig NoExt [mkName lensName] hsWc)
  where
    hsWc = HsWC NoExt hsIb
    hsIb = HsIB NoExt (emptyL result)
    -- Parenthesise the field type so compound types such as @Maybe Int@
    -- also pretty-print correctly.
    result = tyC "Lens'" `appTy` tyC recordType `appTy` HsParTy NoExt (emptyL fieldType)

-- | Type application with dummy source locations.
appTy :: HsType GhcPs -> HsType GhcPs -> HsType GhcPs
appTy = HsAppTy NoExt `on` emptyL

-- | An unqualified, unpromoted type constructor reference.
tyC :: String -> HsType GhcPs
tyC s = HsTyVar NoExt NotPromoted (mkTyCName s)

-- | All rewritable record types declared in the module, each paired with its
-- (field name, field type) list. Only types accepted by
-- 'getFieldAndTypeName' appear.
getDecls :: HsModule GhcPs -> [(TypeName, [(FieldName, HsType GhcPs)])]
getDecls modul =
  [ (getDeclTypeName d, fields)
  | L _ (TyClD _ d) <- hsmodDecls modul
  , Just fields <- [getFieldAndTypeName d]
  ]

-- | Unqualified value-namespace name with no source location.
mkName :: String -> Located RdrName
mkName = emptyL . mkRdrUnqual . mkOccName OccName.varName

-- | Unqualified type-constructor-namespace name with no source location.
mkTyCName :: String -> Located RdrName
mkTyCName = emptyL . mkRdrUnqual . mkOccName OccName.tcName

-- | Name of a @data@/@newtype@ declaration; empty for other declaration kinds.
getDeclTypeName :: TyClDecl GhcPs -> TypeName
getDeclTypeName DataDecl { tcdLName = L _ name } = occNameString (rdrNameOcc name)
getDeclTypeName _ = mempty

-- | Fields of a type this plugin can rewrite, or 'Nothing' when the type
-- must be left to Template Haskell.
--
-- The type must satisfy all of these conditions:
--
-- * it has no type parameters, which would need a type-changing lens;
-- * it has exactly one constructor, because several constructors would
--   need traversals or yield duplicate lenses;
-- * that constructor is Haskell98-style, non-existential, context-free and
--   uses record syntax.
--
-- Each field name from a shared declaration such as @{ _a, _b :: Int }@ is
-- listed separately. Strictness and Haddock wrappers are removed from the
-- field types.
getFieldAndTypeName :: TyClDecl GhcPs -> Maybe [(FieldName, HsType GhcPs)]
getFieldAndTypeName DataDecl { tcdTyVars   = HsQTvs { hsq_explicit = [] }
                             , tcdDataDefn = HsDataDefn { dd_cons = [L _ con] } } =
  case con of
    ConDeclH98 { con_ex_tvs = [], con_mb_cxt = Nothing, con_args = RecCon (L _ flds) } ->
      Just
        [ (occNameString (rdrNameOcc rdr), unwrapFieldType ty)
        | L _ ConDeclField { cd_fld_names = names, cd_fld_type = L _ ty } <- flds
        , L _ FieldOcc { rdrNameFieldOcc = L _ rdr } <- names
        ]
    _ -> Nothing
getFieldAndTypeName _ = Nothing

-- | Remove bang or UNPACK annotations and Haddock comments from a field type.
-- Those wrappers are not valid inside an ordinary type signature.
unwrapFieldType :: HsType GhcPs -> HsType GhcPs
unwrapFieldType (HsBangTy _ _ (L _ t)) = unwrapFieldType t
unwrapFieldType (HsDocTy _ (L _ t) _)  = unwrapFieldType t
unwrapFieldType t                      = t

-- | Lens name for a record field, following lens's default
-- @underscoreNoPrefixNamer@: @_fooBar@ gives @fooBar@ and @_Foo@ gives
-- @foo@. Fields without a leading underscore, or named just @_@, get no lens.
lensNameFor :: FieldName -> Maybe String
lensNameFor ('_' : c : cs) = Just (Char.toLower c : cs)
lensNameFor _              = Nothing

-- | Signature and binding for one field's lens:
--
-- > foo :: Lens' T (X)
-- > foo = lens _foo (\r f -> r { _foo = f })
--
-- Returns no declarations when the field gets no lens (see 'lensNameFor').
genLensCall :: TypeName -> (FieldName, HsType GhcPs) -> [HsDecl GhcPs]
genLensCall recordType (fieldName, fieldType) =
  case lensNameFor fieldName of
    Nothing -> []
    Just lensName ->
      [ genSigD lensName recordType fieldType
      , valD (funBind lensName (matchGroup [match (funRhs lensName) (grhss [grhs lensExpr])]))
      ]
  where
    -- lens _foo (\r f -> r { _foo = f })
    lensExpr = hsVar "lens" `hsApp` hsVar fieldName `hsApp` hsPar setterLam
    setterLam =
      hsLam $ matchGroup
        [ lambdaMatch [varPat "r", varPat "f"] $
            grhss [grhs (recordUpd (hsVar "r") [hsRecUpdField fieldName (hsVar "f")])]
        ]

-- | Wrap a binding as a top-level declaration.
valD :: HsBind GhcPs -> HsDecl GhcPs
valD = ValD NoExt

-- | Function binding with no wrapper and no ticks.
funBind :: String -> MatchGroup GhcPs (LHsExpr GhcPs) -> HsBind GhcPs
funBind s mg = FunBind NoExt (mkName s) mg WpHole []

-- | Source-origin match group from a list of matches.
matchGroup :: [Match GhcPs (LHsExpr GhcPs)] -> MatchGroup GhcPs (LHsExpr GhcPs)
matchGroup xs = MG NoExt (emptyL (fmap emptyL xs)) FromSource

-- | Zero-argument match in the given context (see 'funRhs').
match
  :: HsMatchContext (NameOrRdrName (IdP GhcPs)) -- see funRhs
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> Match GhcPs (LHsExpr GhcPs)
match x y = Match NoExt x [] y

-- | Lambda alternative binding the given patterns.
lambdaMatch :: [LPat GhcPs] -> GRHSs GhcPs (LHsExpr GhcPs) -> Match GhcPs (LHsExpr GhcPs)
lambdaMatch xs y = Match NoExt LambdaExpr xs y

-- | Match context for a prefix, non-strict binding of the named function.
funRhs :: String -> HsMatchContext (NameOrRdrName (IdP GhcPs))
funRhs x = FunRhs (mkName x) Prefix NoSrcStrict

-- | Attach an empty source span.
emptyL :: e -> GenLocated SrcSpan e
emptyL = L noSrcSpan

-- | Right-hand sides with no @where@ bindings.
grhss :: [GRHS GhcPs (LHsExpr GhcPs)] -> GRHSs GhcPs (LHsExpr GhcPs)
grhss xs = GRHSs NoExt (fmap emptyL xs) (emptyL (EmptyLocalBinds NoExt))

-- | Unguarded right-hand side.
grhs :: HsExpr GhcPs -> GRHS GhcPs (LHsExpr GhcPs)
grhs = GRHS NoExt [] . emptyL

-- | Function application.
hsApp :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs
hsApp l r = HsApp NoExt (emptyL l) (emptyL r)

-- | Unqualified variable reference.
hsVar :: String -> HsExpr GhcPs
hsVar = HsVar NoExt . mkName

-- | Parenthesised expression.
hsPar :: HsExpr GhcPs -> HsExpr GhcPs
hsPar = HsPar NoExt . emptyL

-- | Lambda expression.
hsLam :: MatchGroup GhcPs (LHsExpr GhcPs) -> HsExpr GhcPs
hsLam = HsLam NoExt

-- | Record update @e { f1 = …, … }@.
recordUpd :: HsExpr GhcPs -> [HsRecUpdField GhcPs] -> HsExpr GhcPs
recordUpd e fs = RecordUpd NoExt (emptyL e) (emptyL <$> fs)

-- | Non-punned record update field @s = e@.
hsRecUpdField :: String -> HsExpr GhcPs -> HsRecUpdField GhcPs
hsRecUpdField s e = HsRecField (emptyL (Unambiguous NoExt (mkName s))) (emptyL e) False

-- | Variable pattern.
varPat :: String -> Pat GhcPs
varPat = VarPat NoExt . mkName