{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A @protoc@ plugin that emits a single Haskell module,
-- @Proto.SignableOrphan@, containing orphan @Data.Signable.Signable@
-- instances for every message and enum type described by the proto files in
-- the incoming 'CodeGeneratorRequest'.
module Main where

import qualified Data.ByteString as B
import Data.Int (Int32)
import Data.List (sortBy)
import Data.Maybe (isJust)
import Data.Ord (comparing)
import Data.ProtoLens (decodeMessage, defMessage, encodeMessage)
import Data.ProtoLens.Compiler.ModuleName (protoModuleName)
import Data.ProtoLens.Labels ()
import qualified Data.Set as Set
import Data.String (fromString)
import qualified Data.Text as T
import Data.Text (Text, intercalate, pack, unpack)
import DynFlags (DynFlags, getDynFlags)
import GHC (runGhc)
import GHC.Paths (libdir)
import GHC.SourceGen
import GHC.SourceGen.Pretty (showPpr)
import GhcMonad (liftIO)
import Lens.Family2
import Proto.Google.Protobuf.Compiler.Plugin
  ( CodeGeneratorRequest,
    CodeGeneratorResponse,
  )
import Proto.Google.Protobuf.Descriptor
  ( DescriptorProto,
    EnumDescriptorProto,
    FieldDescriptorProto,
    FieldDescriptorProto'Label (FieldDescriptorProto'LABEL_REPEATED),
    FieldDescriptorProto'Type (FieldDescriptorProto'TYPE_MESSAGE),
    FileDescriptorProto,
  )
import System.Environment (getProgName)
import System.Exit (ExitCode (..), exitWith)
import qualified System.IO as IO
import Text.Casing (camel)

-- | One input @.proto@ file: the Haskell module name computed for it and the
-- types it defines.
data ProtoMod = ProtoMod
  { -- | Module name as returned by 'protoModuleName' for the file's name.
    modName :: String,
    -- | All message and enum types of the file, including nested ones.
    modTypes :: [ProtoType]
  }
  deriving (Show)

-- | A type that receives a generated @Signable@ instance. The 'String' is the
-- type name relative to its module; nested type names are joined with @'@.
data ProtoType
  = ProtoMsg String DescriptorProto
  | ProtoEnum String
  deriving (Show)

-- | Plugin entry point. Reads a serialized 'CodeGeneratorRequest' from stdin
-- and writes the serialized 'CodeGeneratorResponse' to stdout. If the request
-- cannot be decoded, the decoder's error is printed to stderr and the process
-- exits with status 1.
main :: IO ()
main = do
  contents <- B.getContents
  progName <- getProgName
  case decodeMessage contents of
    Left e -> IO.hPutStrLn IO.stderr e >> exitWith (ExitFailure 1)
    Right x -> runGhc (Just libdir) $ do
      -- A GHC session is started only to obtain DynFlags for 'showPpr'.
      dflags <- getDynFlags
      liftIO $ B.putStr $ encodeMessage $ makeResponse dflags progName x

-- | Build the plugin response. It always contains exactly one file,
-- @Proto/SignableOrphan.hs@. That file holds a comment/pragma header followed
-- by the pretty-printed module with one instance per type of every file
-- listed in the request's @protoFile@.
makeResponse :: DynFlags -> String -> CodeGeneratorRequest -> CodeGeneratorResponse
makeResponse dflags prog req =
  defMessage
    & #file
      .~ [ defMessage
             & #name .~ "Proto/SignableOrphan.hs"
             & #content .~ header <> "\n\n" <> body
         ]
  where
    -- NOTE(review): this walks every entry of @protoFile@ (which may include
    -- dependencies), not only @fileToGenerate@ — confirm this is intended.
    protoMods :: [ProtoMod]
    protoMods =
      (\x -> ProtoMod (parseModName x) $ parseModTypes x) <$> req ^. #protoFile

    -- Every import is qualified: the generated code refers to all names by
    -- their fully-qualified form.
    imports :: [ImportDecl']
    imports =
      qualified'
        <$> ( [import' "Universum", import' "Data.Signable", import' "GHC.List"]
                <> (protoMods >>= modImports)
            )

    -- Each proto module is imported together with its @_Fields@ module,
    -- from which the generated code takes the field accessors it 'view's.
    modImports :: ProtoMod -> [ImportDecl']
    modImports x =
      let n = modName x
       in import' . fromString <$> [n, n <> "_Fields"]

    -- The module exports nothing (@Just []@); it exists only for its instances.
    body :: Text
    body =
      pack . showPpr dflags $
        module'
          (Just "Proto.SignableOrphan")
          (Just [])
          imports
          (protoMods >>= mkImpls)

    header :: Text
    header =
      Data.Text.intercalate
        "\n"
        [ "{- This file was auto-generated by the " <> pack prog <> " program. -}",
          "{-# OPTIONS_GHC -fno-warn-orphans #-}",
          "{-# LANGUAGE NoImplicitPrelude #-}"
        ]

-- | Haskell module name for a proto file, derived from the file's @name@.
parseModName :: FileDescriptorProto -> String
parseModName fd = protoModuleName (T.unpack $ fd ^. #name)

-- | All types declared in a file: its top-level enums, followed by each
-- top-level message together with everything nested inside it.
parseModTypes :: FileDescriptorProto -> [ProtoType]
parseModTypes x =
  (parseEnum mempty <$> x ^. #enumType)
    <> ((x ^. #messageType) >>= parseMsg mempty)

-- | An enum type whose name is prefixed with the namespace @ns@.
parseEnum :: String -> EnumDescriptorProto -> ProtoType
parseEnum ns x = ProtoEnum $ ns <> unpack (x ^. #name)

-- | A message type (prefixed by @ns0@), followed by its nested enums and,
-- recursively, its nested messages. Children are namespaced by the parent's
-- name plus @'@, producing names like @Outer'Inner@.
parseMsg :: String -> DescriptorProto -> [ProtoType]
parseMsg ns0 x =
  ProtoMsg n x
    : (parseEnum ns <$> x ^. #enumType)
    <> ((x ^. #nestedType) >>= parseMsg ns)
  where
    n = ns0 <> unpack (x ^. #name)
    ns = n <> "'"

-- | One @Signable@ instance declaration per type of the module.
mkImpls :: ProtoMod -> [HsDecl']
mkImpls x = mk (modName x) <$> modTypes x
  where
    mk m = \case
      ProtoMsg t d -> mkMsgImpl m t d
      ProtoEnum t -> mkEnumImpl m t

-- | @Signable@ instance for message @m.t@. The generated body is
--
-- > toBinary = mconcat . (<&>) [chunk_1, ..., chunk_k] . (&)
--
-- It applies every per-field chunk function to the message and concatenates
-- the results. Chunks are ordered by ascending field number, so the output
-- does not depend on the order fields are declared in the .proto file.
mkMsgImpl :: String -> String -> DescriptorProto -> HsDecl'
mkMsgImpl m t d =
  instance'
    (var "Data.Signable.Signable" @@ var (fromString $ m <> "." <> t))
    [ funBind "toBinary" $
        match
          []
          ( op
              (var "Universum.mconcat")
              compose
              ( op
                  (var "Universum.<&>" @@ list chunks)
                  compose
                  (var "Universum.&")
              )
          )
    ]
  where
    chunks = mkMsgChunk m <$> sortBy (comparing (^. #number)) (d ^. #field)

-- | Generated function from a message of module @m@ to the encoding of one of
-- its fields. Every non-empty chunk starts with the field's tag, encoded as
-- @toBinary (fieldNumber :: Int32)@.
--
-- * repeated fields: an empty list contributes @mempty@; otherwise the chunk
--   is the tag followed by @toBinary@ of the list;
-- * message-typed fields and oneof members: read through the @maybe'@
--   accessor. An absent value contributes @mempty@; a present one contributes
--   the tag followed by @toBinary@ of the value;
-- * all other fields: always the tag followed by @toBinary@ of the value.
mkMsgChunk :: String -> FieldDescriptorProto -> HsExpr'
mkMsgChunk m d
  | d ^. #label == FieldDescriptorProto'LABEL_REPEATED = rExpr
  | d ^. #type' == FieldDescriptorProto'TYPE_MESSAGE
      || isJust (d ^. #maybe'oneofIndex) =
    mExpr
  | otherwise = expr
  where
    n0 = unReserve . camel . unpack $ d ^. #name
    -- The parentheses pin the Int32 signature to the literal itself
    -- (@toBinary (n :: Int32)@) instead of relying on the relative fixities
    -- of ghc-source-gen's '@@' and '@::@'.
    tag = case (safeFromIntegral (d ^. #number) :: Maybe Int32) of
      Just v ->
        var "Data.Signable.toBinary"
          @@ (int (fromIntegral v :: Integer) @::@ var "Universum.Int32")
      -- Defensive only: the range check failing means the number is outside Int32.
      Nothing -> error "TAG_OVERFLOW"
    -- @Universum.view M_Fields.<prefix><field>@
    fieldView prefix =
      var "Universum.view" @@ var (fromString $ m <> prefix <> n0)
    -- @(Universum.<>) tag . Data.Signable.toBinary@
    tagged =
      op (var "Universum.<>" @@ tag) compose (var "Data.Signable.toBinary")
    -- NOTE(review): assumes @Data.Signable.ifThenElse p f g x@ is
    -- @if p x then f x else g x@ — confirm against Data.Signable.
    rExpr =
      op
        ( var "Data.Signable.ifThenElse"
            @@ var "GHC.List.null"
            @@ (var "Universum.const" @@ var "Universum.mempty")
            @@ tagged
        )
        compose
        (fieldView "_Fields.")
    mExpr =
      op
        (var "Universum.maybe" @@ var "Universum.mempty" @@ tagged)
        compose
        (fieldView "_Fields.maybe'")
    expr =
      op
        (var "Universum.<>" @@ tag)
        compose
        (op (var "Data.Signable.toBinary") compose (fieldView "_Fields."))

-- | @Signable@ instance for enum @m.t@. The generated body is
--
-- > toBinary = maybe (error "ENUM_OVERFLOW") toBinary
-- >          . (safeFromIntegral :: Int -> Maybe Int32) . fromEnum
--
-- so the enum's 'fromEnum' value is encoded as an Int32. The generated code
-- calls @error@ if that value does not fit in an Int32.
mkEnumImpl :: String -> String -> HsDecl'
mkEnumImpl m t =
  instance'
    (var "Data.Signable.Signable" @@ var (fromString $ m <> "." <> t))
    [ funBind "toBinary" $
        match
          []
          ( op
              ( var "Universum.maybe"
                  @@ (var "Universum.error" @@ string "ENUM_OVERFLOW")
                  @@ var "Data.Signable.toBinary"
              )
              compose
              ( op
                  ( var "Data.Signable.safeFromIntegral"
                      @::@ ( var "Universum.Int"
                               --> var "Universum.Maybe" @@ var "Universum.Int32"
                           )
                  )
                  compose
                  (var "Universum.fromEnum")
              )
          )
    ]

-- | Function composition as it is referenced from the generated module,
-- which is compiled with NoImplicitPrelude and imports Universum qualified.
compose :: RdrNameStr
compose = "Universum.."

-- | Append @'@ to a name that collides with a reserved keyword, leaving all
-- other names unchanged.
unReserve :: String -> String
unReserve x
  | x `Set.member` reservedKeywords = x <> "'"
  | otherwise = x

-- | A list of reserved keywords that aren't valid as variable names.
reservedKeywords :: Set.Set String
reservedKeywords =
  Set.fromList $
    -- Haskell2010 keywords:
    -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
    -- We don't include keywords that are allowed to be variable names,
    -- in particular: "as", "forall", and "hiding".
    [ "case",
      "class",
      "data",
      "default",
      "deriving",
      "do",
      "else",
      "foreign",
      "if",
      "import",
      "in",
      "infix",
      "infixl",
      "infixr",
      "instance",
      "let",
      "module",
      "newtype",
      "of",
      "then",
      "type",
      "where"
    ]
      ++ [ -- Nonstandard extensions
           "mdo", -- RecursiveDo
           "rec", -- Arrows, RecursiveDo
           "pattern", -- PatternSynonyms
           "proc" -- Arrows
         ]

-- | Convert between integral types, returning 'Nothing' when the value lies
-- outside @[minBound, maxBound]@ of the target type. The comparison is done
-- in 'Integer', so it cannot overflow.
safeFromIntegral :: forall a b. (Integral a, Integral b, Bounded b) => a -> Maybe b
safeFromIntegral x =
  if (intX >= intMin) && (intX <= intMax)
    then Just $ fromIntegral x
    else Nothing
  where
    intX = fromIntegral x :: Integer
    intMin = fromIntegral (minBound :: b) :: Integer
    intMax = fromIntegral (maxBound :: b) :: Integer