{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.ProtoLens.Compiler.Definitions
( Env
, Definition(..)
, MessageInfo(..)
, ServiceInfo(..)
, MethodInfo(..)
, PlainFieldInfo(..)
, FieldInfo(..)
, FieldKind(..)
, FieldPacking(..)
, MapEntryInfo(..)
, OneofInfo(..)
, OneofCase(..)
, FieldName(..)
, Symbol
, nameFromSymbol
, promoteSymbol
, EnumInfo(..)
, EnumValueInfo(..)
, EnumUnrecognizedInfo(..)
, qualifyEnv
, unqualifyEnv
, collectDefinitions
, collectServices
, definedFieldType
, definedType
, camelCase
, overloadedFieldName
) where
import Control.Applicative (liftA2)
import Data.Char (isUpper, toUpper)
import Data.Int (Int32)
import Data.List (mapAccumL)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
import qualified Data.Semigroup as Semigroup
import qualified Data.Set as Set
import Data.String (IsString(..))
import Data.Text (Text, cons, splitOn, toLower, uncons, unpack)
import qualified Data.Text as T
import Data.Tree
( Tree(..)
, Forest
, flatten
)
import Lens.Family2 ((^.), (^..), toListOf)
import Proto.Google.Protobuf.Descriptor
( DescriptorProto
, EnumDescriptorProto
, EnumValueDescriptorProto
, FieldDescriptorProto
, FieldDescriptorProto'Label(..)
, FieldDescriptorProto'Type(..)
, FileDescriptorProto
, MethodDescriptorProto
, ServiceDescriptorProto
)
import Proto.Google.Protobuf.Descriptor_Fields
( clientStreaming
, enumType
, field
, inputType
, label
, mapEntry
, maybe'oneofIndex
, maybe'packed
, messageType
, method
, name
, nestedType
, number
, oneofDecl
, options
, outputType
, package
, serverStreaming
, service
, syntax
, type'
, typeName
, value
)
import Data.ProtoLens.Compiler.Combinators
( Name
, QName
, ModuleName
, Type
, qual
, tyPromotedString
, unQual
)
-- | Definitions collected from @.proto@ files, keyed by their fully
-- qualified protobuf name (e.g. @.pkg.Foo@). The parameter is the type
-- used for generated Haskell names.
type Env a = Map.Map Text (Definition a)
-- | The protobuf syntax a file is written in.
data SyntaxType
    = Proto2
    | Proto3
    deriving (Eq, Show)
-- | Determine a file's 'SyntaxType' from its @syntax@ field.
--
-- An empty @syntax@ value is treated as proto2. Any value other than
-- @\"proto2\"@, @\"proto3\"@ or empty is rejected with 'error'.
fileSyntaxType :: FileDescriptorProto -> SyntaxType
fileSyntaxType f
    | declared == "proto3" = Proto3
    | declared `elem` ["proto2", ""] = Proto2
    | otherwise = error $ "Unknown syntax type " ++ show declared
  where
    declared = f ^. syntax
-- | A named definition collected from a @.proto@ file: either a message or
-- an enum.
--
-- The parameter @n@ is the type used for generated Haskell names.
-- 'qualifyEnv' and 'unqualifyEnv' convert it from 'Name' to 'QName'.
data Definition n
    = Message (MessageInfo n)
    | Enum (EnumInfo n)
    deriving Functor
-- | Everything the code generator needs to know about one protobuf message.
data MessageInfo n = MessageInfo
    { messageName :: n
      -- ^ Haskell type name: the enclosing Haskell prefix followed by the
      -- capitalized proto name.
    , messageDescriptor :: DescriptorProto
      -- ^ The original descriptor of the message.
    , messageFields :: [PlainFieldInfo]
      -- ^ Fields that are not part of any @oneof@.
    , messageOneofFields :: [OneofInfo]
      -- ^ One entry per @oneof@ declared in the message.
    , messageUnknownFields :: Name
      -- ^ Generated name of the form @_Prefix'_unknownFields@.
    , groupFieldNumber :: Maybe Int32
      -- ^ When this message is the type of a group field in its parent,
      -- the number of that field.
    } deriving Functor
-- | A protobuf service declared in a file.
data ServiceInfo = ServiceInfo
    { serviceName :: Text
      -- ^ The service's name as written in the @.proto@ file.
    , servicePackage :: Text
      -- ^ The protobuf package of the declaring file.
    , serviceMethods :: [MethodInfo]
      -- ^ The service's RPC methods, in declaration order.
    }
-- | One RPC method of a service.
data MethodInfo = MethodInfo
    { methodName :: Text
      -- ^ The method's name as written in the @.proto@ file.
    , methodIdent :: Text
      -- ^ 'camelCase' form of 'methodName'.
    , methodInput :: Text
      -- ^ The descriptor's @input_type@.
    , methodOutput :: Text
      -- ^ The descriptor's @output_type@.
    , methodClientStreaming :: Bool
      -- ^ The descriptor's @client_streaming@ flag.
    , methodServerStreaming :: Bool
      -- ^ The descriptor's @server_streaming@ flag.
    }
-- | A message field that does not belong to a @oneof@.
data PlainFieldInfo = PlainFieldInfo
    { plainFieldKind :: FieldKind
      -- ^ How the field is represented (required, optional, repeated, map).
    , plainFieldInfo :: FieldInfo
      -- ^ The descriptor and generated names of the field.
    }
-- | A field descriptor paired with the Haskell names generated for it.
data FieldInfo = FieldInfo
    { fieldDescriptor :: FieldDescriptorProto
      -- ^ The original field descriptor.
    , fieldName :: FieldName
      -- ^ Overloaded label and record-field name for the field.
    }
-- | The representation chosen for a non-oneof field (see 'fieldKind').
data FieldKind
    = RequiredField
      -- ^ Declared with @required@.
    | OptionalValueField
      -- ^ Optional proto3 field whose type is not a message.
    | OptionalMaybeField
      -- ^ Any other optional field.
    | RepeatedField FieldPacking
      -- ^ Repeated field, with its packing behavior.
    | MapField MapEntryInfo
      -- ^ Repeated field whose element type is a map-entry message.
-- | Packing behavior of a repeated field (computed in 'fieldKind').
data FieldPacking
    = NotPackable
      -- ^ Element type is a message, group, string or bytes.
    | Packable
      -- ^ Could be packed, but is not packed by option or default.
    | Packed
      -- ^ Packed, either via the @packed@ option or the proto3 default.
    deriving Eq
-- | A @oneof@ declaration within a message.
data OneofInfo = OneofInfo
    { oneofFieldName :: FieldName
      -- ^ Generated names for the field holding the oneof.
    , oneofTypeName :: Name
      -- ^ Generated Haskell type name for the oneof's sum type.
    , oneofCases :: [OneofCase]
      -- ^ One case per field that belongs to this oneof.
    }
-- | A single field inside a @oneof@.
data OneofCase = OneofCase
    { caseField :: FieldInfo
      -- ^ The field itself.
    , caseConstructorName :: Name
      -- ^ Generated constructor name for this case.
    , casePrismName :: Name
      -- ^ The constructor name with a leading underscore.
    }
-- | A message marked with the @map_entry@ option that has exactly two
-- plain fields.
data MapEntryInfo = MapEntryInfo
    { mapEntryTypeName :: Name
      -- ^ Haskell name of the entry message.
    , keyField :: FieldInfo
      -- ^ First field of the entry message.
    , valueField :: FieldInfo
      -- ^ Second field of the entry message.
    }
-- | The two Haskell names generated for a field (see 'mkFieldName').
data FieldName = FieldName
    { overloadedName :: Symbol
      -- ^ camelCased base name, with a trailing @'@ if it is a keyword.
    , haskellRecordFieldName :: Name
      -- ^ @_@, then the enclosing Haskell prefix, then the base name.
    }
-- | A name usable both as a Haskell 'Name' ('nameFromSymbol') and as a
-- promoted type-level string ('promoteSymbol').
newtype Symbol = Symbol String
    deriving
        ( Eq
        , Ord
        , IsString
        , Semigroup.Semigroup
        , Monoid
        )
-- | Use a 'Symbol' as an ordinary Haskell 'Name'.
nameFromSymbol :: Symbol -> Name
nameFromSymbol (Symbol str) = fromString str
-- | Use a 'Symbol' as a promoted string literal at the type level.
promoteSymbol :: Symbol -> Type
promoteSymbol (Symbol str) = tyPromotedString str
-- | Everything the code generator needs to know about one protobuf enum.
data EnumInfo n = EnumInfo
    { enumName :: n
      -- ^ Generated Haskell type name.
    , enumUnrecognized :: Maybe EnumUnrecognizedInfo
      -- ^ 'Nothing' for proto2 files; generated names otherwise.
    , enumDescriptor :: EnumDescriptorProto
      -- ^ The original descriptor of the enum.
    , enumValues :: [EnumValueInfo n]
      -- ^ One entry per declared value, in declaration order.
    } deriving Functor
-- | Extra names generated for enums in non-proto2 files (see 'enumDef').
data EnumUnrecognizedInfo = EnumUnrecognizedInfo
    { unrecognizedName :: Name
      -- ^ Derived from the enum name with @'Unrecognized@ appended.
    , unrecognizedValueName :: Name
      -- ^ Derived from the enum name with @'UnrecognizedValue@ appended.
    }
-- | One value declared in an enum.
data EnumValueInfo n = EnumValueInfo
    { enumValueName :: n
      -- ^ Generated Haskell name for the value.
    , enumValueDescriptor :: EnumValueDescriptorProto
      -- ^ The original value descriptor.
    , enumAliasOf :: Maybe Name
      -- ^ Name of the first earlier value with the same number, if any.
    } deriving Functor
-- | Apply a name transformation to every definition in an environment.
mapEnv :: (n -> n') -> Env n -> Env n'
mapEnv f = fmap (fmap f)
-- | Qualify every generated name in the environment with the given module.
qualifyEnv :: ModuleName -> Env Name -> Env QName
qualifyEnv = mapEnv . qual
-- | Turn every generated name in the environment into an unqualified 'QName'.
unqualifyEnv :: Env Name -> Env QName
unqualifyEnv env = mapEnv unQual env
-- | Look up the definition named by a field's @type_name@.
--
-- Calls 'error' if the type is not in the environment.
definedFieldType :: FieldDescriptorProto -> Env QName -> Definition QName
definedFieldType fd env =
    case Map.lookup tyName env of
        Just def -> def
        Nothing ->
            error $ "definedFieldType: Field type " ++ unpack tyName
                ++ " not found in environment."
  where
    tyName = fd ^. typeName
-- | Look up a definition by its fully qualified protobuf name.
--
-- Calls 'error' if the name is not in the environment.
definedType :: Text -> Env QName -> Definition QName
definedType ty env = Map.findWithDefault missing ty env
  where
    -- Only forced when the lookup fails.
    missing = error $ "definedType: Type " ++ unpack ty
        ++ " not found in environment."
-- | Collect every message and enum defined in a file, including nested
-- ones, keyed by fully qualified protobuf name (leading @.@, then the
-- package, if any).
collectDefinitions :: FileDescriptorProto -> Env Name
collectDefinitions fd =
    Map.fromList . concatMap flatten $
        messageAndEnumDefs
            (fileSyntaxType fd)
            protoPrefix
            ""  -- top-level definitions get no Haskell prefix
            Map.empty
            (fd ^. messageType)
            (fd ^. enumType)
  where
    pkg = fd ^. package
    protoPrefix
        | T.null pkg = "."
        | otherwise = "." <> pkg <> "."
-- | Collect the services declared in a file, tagged with the file's package.
collectServices :: FileDescriptorProto -> [ServiceInfo]
collectServices fd = fmap (toServiceInfo $ fd ^. package) $ fd ^. service
  where
    toServiceInfo :: Text -> ServiceDescriptorProto -> ServiceInfo
    toServiceInfo pkg sd = ServiceInfo
        { serviceName = sd ^. name
        , servicePackage = pkg
        , serviceMethods = fmap toMethodInfo $ sd ^. method
        }

    toMethodInfo :: MethodDescriptorProto -> MethodInfo
    toMethodInfo md = MethodInfo
        { methodName = md ^. name
        , methodIdent = camelCase $ md ^. name
          -- input_type / output_type are already Text; use them as-is
          -- rather than copying through String.
        , methodInput = md ^. inputType
        , methodOutput = md ^. outputType
        , methodClientStreaming = md ^. clientStreaming
        , methodServerStreaming = md ^. serverStreaming
        }
-- | Build definition trees for a list of sibling messages and enums.
--
-- Each message becomes a tree whose children are its nested definitions.
-- Each enum becomes a leaf. Messages come first, then enums.
messageAndEnumDefs ::
    SyntaxType -> Text -> String
    -> GroupMap
    -> [DescriptorProto]
    -> [EnumDescriptorProto]
    -> Forest (Text, Definition Name)
messageAndEnumDefs syntaxType protoPrefix hsPrefix groups messages enums =
    messageTrees ++ enumTrees
  where
    messageTrees =
        messageDefs syntaxType protoPrefix hsPrefix groups <$> messages
    enumTrees =
        [ Node (enumDef syntaxType protoPrefix hsPrefix e) []
        | e <- enums
        ]
-- | Build the definition tree for one message.
--
-- The root holds the message's own 'MessageInfo'. The children are its
-- nested messages and enums, which are named with this message's Haskell
-- prefix (e.g. @Foo'@).
messageDefs :: SyntaxType -> Text -> String -> GroupMap -> DescriptorProto
            -> Tree (Text, Definition Name)
messageDefs syntaxType protoPrefix hsPrefix groups d =
    Node (protoName, Message info) nestedDefs
  where
    protoName = protoPrefix <> d ^. name
    typeHsName = hsName (d ^. name)
    -- Prefix for every Haskell name nested inside this message.
    hsPrefix' = hsPrefix ++ typeHsName ++ "'"
    fieldsByOneof = groupFieldsByOneofIndex (d ^. field)
    toPlainField =
        liftA2 PlainFieldInfo
            (fieldKind syntaxType mapEntries)
            (fieldInfo hsPrefix')
    info = MessageInfo
        { messageName = fromString (hsPrefix ++ typeHsName)
        , messageDescriptor = d
          -- Fields with no oneof index are the plain fields.
        , messageFields =
            map toPlainField (Map.findWithDefault [] Nothing fieldsByOneof)
        , messageOneofFields = collectOneofFields hsPrefix' d fieldsByOneof
        , messageUnknownFields =
            fromString ("_" ++ hsPrefix' ++ "_unknownFields")
        , groupFieldNumber = Map.lookup protoName groups
        }
    nestedDefs =
        messageAndEnumDefs
            syntaxType
            (protoName <> ".")
            hsPrefix'
            (collectGroupFields (d ^. field))
            (d ^. nestedType)
            (d ^. enumType)
    -- Map-entry messages among the direct children, used by fieldKind.
    mapEntries = collectMapEntries (map rootLabel nestedDefs)
-- | Recognize a message that has the @map_entry@ option set and exactly two
-- plain fields. The first field is the key and the second is the value.
mapEntryInfo :: Definition Name -> Maybe MapEntryInfo
mapEntryInfo def = case def of
    Message m
        | messageDescriptor m ^. options . mapEntry
        , [keyFd, valueFd] <- messageFields m
        -> Just MapEntryInfo
            { mapEntryTypeName = messageName m
            , keyField = plainFieldInfo keyFd
            , valueField = plainFieldInfo valueFd
            }
    _ -> Nothing
-- | Index the map-entry messages among the given definitions by their
-- protobuf names.
collectMapEntries :: [(Text, Definition Name)] -> Map.Map Text MapEntryInfo
collectMapEntries defs = Map.fromList (foldr keep [] defs)
  where
    keep (protoName, def) acc =
        case mapEntryInfo def of
            Just entry -> (protoName, entry) : acc
            Nothing -> acc
-- | Maps the @type_name@ of each group-typed field to that field's number.
type GroupMap
    = Map.Map Text Int32
-- | Build a 'GroupMap' from the @TYPE_GROUP@ fields among the given fields.
collectGroupFields :: [FieldDescriptorProto] -> GroupMap
collectGroupFields =
    Map.fromList . map toEntry . filter isGroup
  where
    isGroup f = f ^. type' == FieldDescriptorProto'TYPE_GROUP
    toEntry f = (f ^. typeName, f ^. number)
-- | Pair a field descriptor with the names generated for it under the given
-- Haskell prefix.
fieldInfo :: String -> FieldDescriptorProto -> FieldInfo
fieldInfo hsPrefix f =
    FieldInfo { fieldDescriptor = f, fieldName = names }
  where
    names = mkFieldName hsPrefix (f ^. name)
-- | Classify a non-oneof field.
--
-- * @required@ fields become 'RequiredField'.
-- * @optional@ fields become 'OptionalValueField' in proto3 when the type
--   is not a message, and 'OptionalMaybeField' otherwise.
-- * @repeated@ fields whose @type_name@ is in @mapEntries@ become
--   'MapField'. Other repeated fields become 'RepeatedField' with their
--   packing: not packable for message, group, string and bytes types, and
--   otherwise packed when the @packed@ option says so. When that option is
--   absent, they are packed in proto3 only.
fieldKind ::
    SyntaxType -> Map.Map Text MapEntryInfo -> FieldDescriptorProto
    -> FieldKind
fieldKind syntaxType mapEntries f =
    case f ^. label of
        FieldDescriptorProto'LABEL_REQUIRED -> RequiredField
        FieldDescriptorProto'LABEL_OPTIONAL
            | isProto3 && not isMessage -> OptionalValueField
            | otherwise -> OptionalMaybeField
        FieldDescriptorProto'LABEL_REPEATED ->
            maybe (RepeatedField packing) MapField
                (Map.lookup (f ^. typeName) mapEntries)
  where
    fieldType = f ^. type'
    isProto3 = syntaxType == Proto3
    isMessage = fieldType == FieldDescriptorProto'TYPE_MESSAGE
    packing
        | fieldType `elem` unpackableTypes = NotPackable
        | fromMaybe isProto3 (f ^. options . maybe'packed) = Packed
        | otherwise = Packable
    unpackableTypes =
        [ FieldDescriptorProto'TYPE_MESSAGE
        , FieldDescriptorProto'TYPE_GROUP
        , FieldDescriptorProto'TYPE_STRING
        , FieldDescriptorProto'TYPE_BYTES
        ]
-- | Build the 'OneofInfo' for every @oneof@ declared in a message.
--
-- The oneof declared at (0-based) position @i@ collects the fields grouped
-- under @Just i@ in @allFields@.
--
-- Generated names get a trailing @'@ when they would clash. For the oneof
-- type, a clash is with a nested message or nested enum name. For a case
-- constructor, a clash is with a nested message name or a nested enum
-- value name.
collectOneofFields
    :: String -> DescriptorProto -> Map.Map (Maybe Int32) [FieldDescriptorProto]
    -> [OneofInfo]
collectOneofFields hsPrefix d allFields =
    zipWith oneofInfo [0 ..] $ d ^.. oneofDecl . traverse . name
  where
    oneofInfo idx n = OneofInfo
        { oneofFieldName = mkFieldName hsPrefix n
        , oneofTypeName = fromString $ hsPrefix ++ hsNameUnique subdefTypes n
        , oneofCases = map oneofCase
            $ Map.findWithDefault [] (Just idx) allFields
        }
    oneofCase f =
        let consName = hsPrefix ++ hsNameUnique subdefCons (f ^. name)
        in OneofCase
            { caseField = fieldInfo hsPrefix f
            , caseConstructorName = fromString consName
            , casePrismName = fromString $ "_" ++ consName
            }
    -- Set.member is O(log n). Foldable's elem on a Set scans every element.
    hsNameUnique :: Set.Set String -> Text -> String
    hsNameUnique ns n
        | n' `Set.member` ns = n' ++ "'"
        | otherwise = n'
      where
        n' = hsName $ camelCase n
    -- Names of nested types (messages and enums).
    subdefTypes = Set.fromList $ map hsName
        $ toListOf (nestedType . traverse . name) d
        ++ toListOf (enumType . traverse . name) d
    -- Names that would be nested constructors (messages and enum values).
    subdefCons = Set.fromList $ map hsName
        $ toListOf (nestedType . traverse . name) d
        ++ toListOf (enumType . traverse . value . traverse . name) d
-- | Group fields by their @oneof_index@ (with 'Nothing' for fields outside
-- any oneof), keeping declaration order within each group.
groupFieldsByOneofIndex
    :: [FieldDescriptorProto] -> Map.Map (Maybe Int32) [FieldDescriptorProto]
groupFieldsByOneofIndex fs =
    -- fromListWith (++) prepends later fields, so each group is reversed
    -- back into declaration order.
    Map.map reverse $ Map.fromListWith (++)
        [ (f ^. maybe'oneofIndex, [f]) | f <- fs ]
-- | Capitalize a protobuf name and convert it to a 'String'.
hsName :: Text -> String
hsName t = unpack (capitalize t)
-- | Generate both names for a field. The overloaded label is the base name.
-- The record field name is @_@, then the Haskell prefix, then the base name.
mkFieldName :: String -> Text -> FieldName
mkFieldName hsPrefix n =
    let base = fieldBaseName n
    in FieldName
        { overloadedName = fromString base
        , haskellRecordFieldName = fromString ('_' : hsPrefix ++ base)
        }
-- | camelCase a field name, appending @'@ when the result is in
-- 'reservedKeywords'.
fieldBaseName :: Text -> String
fieldBaseName n
    | base `Set.member` reservedKeywords = unpack (base <> "'")
    | otherwise = unpack base
  where
    base = camelCase n
-- | Convert a snake_case name to camelCase.
--
-- Leading underscores are kept. The first word's leading uppercase run is
-- lowercased, and every later word is capitalized. Calls 'error' if the
-- name consists only of underscores.
camelCase :: Text -> Text
camelCase s =
    case splitOn "_" rest of
        [] -> error $ "camelCase: splitOn returned empty list: "
                ++ show rest
        [""] -> error $ "camelCase: name consists only of underscores: "
                ++ show s
        firstWord : laterWords ->
            T.concat
                (leading : lowerInitialChars firstWord : map capitalize laterWords)
  where
    (leading, rest) = T.span (== '_') s
-- | Lowercase the run of uppercase characters at the start of the text.
lowerInitialChars :: Text -> Text
lowerInitialChars s =
    let (upperRun, remainder) = T.span isUpper s
    in toLower upperRun <> remainder
-- | Identifiers that cannot be used as generated field names. 'fieldBaseName'
-- appends @'@ to these.
reservedKeywords :: Set.Set Text
reservedKeywords =
    Set.fromList . map T.pack $ haskell2010Keywords ++ extensionKeywords
  where
    -- Haskell 2010 reserved words.
    haskell2010Keywords =
        [ "case", "class", "data", "default", "deriving", "do", "else"
        , "foreign", "if", "import", "in", "infix", "infixl", "infixr"
        , "instance", "let", "module", "newtype", "of", "then", "type"
        , "where"
        ]
    -- Words treated as keywords by GHC language extensions.
    extensionKeywords =
        [ "mdo", "rec", "pattern", "proc" ]
-- | Build the definition for an enum, keyed by its fully qualified
-- protobuf name.
--
-- Generated Haskell names are the given prefix followed by the capitalized
-- proto name. A leading @_@ in the capitalized name is replaced by @X@
-- (presumably so the result can start a Haskell type or constructor name).
-- Proto2 enums get no 'EnumUnrecognizedInfo'.
enumDef :: SyntaxType -> Text -> String -> EnumDescriptorProto
        -> (Text, Definition Name)
enumDef syntaxType protoPrefix hsPrefix d =
    (protoPrefix <> enumProtoName, Enum info)
  where
    enumProtoName = d ^. name
    mkHsName n = fromString $ hsPrefix ++ fixLeading (hsName n)
    fixLeading ('_' : rest) = 'X' : rest
    fixLeading other = other
    unrecognizedInfo
        | syntaxType == Proto2 = Nothing
        | otherwise = Just EnumUnrecognizedInfo
            { unrecognizedName
                = mkHsName (enumProtoName <> "'Unrecognized")
            , unrecognizedValueName
                = mkHsName (enumProtoName <> "'UnrecognizedValue")
            }
    info = EnumInfo
        { enumName = mkHsName enumProtoName
        , enumUnrecognized = unrecognizedInfo
        , enumDescriptor = d
        , enumValues = collectEnumValues mkHsName (d ^. value)
        }
-- | Build an 'EnumValueInfo' for each enum value, in declaration order.
--
-- A value whose number was already used by an earlier value records that
-- earlier value's name in 'enumAliasOf'.
collectEnumValues :: (Text -> Name) -> [EnumValueDescriptorProto]
                  -> [EnumValueInfo Name]
collectEnumValues mkHsName values = snd (mapAccumL step Map.empty values)
  where
    step :: Map.Map Int32 Name -> EnumValueDescriptorProto
         -> (Map.Map Int32 Name, EnumValueInfo Name)
    step seen v =
        case Map.lookup num seen of
            Just earlier -> (seen, info (Just earlier))
            Nothing -> (Map.insert num valueName seen, info Nothing)
      where
        info = EnumValueInfo valueName v
        valueName = mkHsName (v ^. name)
        num = v ^. number
-- | Uppercase the first character of the text. Empty text is returned
-- unchanged.
capitalize :: Text -> Text
capitalize s = case uncons s of
    Just (c, rest) -> cons (toUpper c) rest
    Nothing -> s
-- | The overloaded label generated for a field.
overloadedFieldName :: FieldInfo -> Symbol
overloadedFieldName info = overloadedName (fieldName info)