{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Backend.Rust
  ( generateRustCode,
  )
where

import qualified Convex.Action.Parser as Action
import qualified Convex.Parser as P
import qualified Convex.Schema.Parser as Schema
import Data.Char (isUpper, toLower, toUpper)
import Data.List (intercalate, isInfixOf, isPrefixOf, nub, stripPrefix)
import qualified Data.Map as Map
import PathTree

-- | Prefix a line of generated Rust with @level@ indentation steps of four
-- spaces each. A zero or negative level leaves the text unchanged.
indent :: Int -> String -> String
indent level text = concat (replicate level "    ") ++ text

-- | Top-level entry point: render the complete Rust source file for a parsed
-- Convex project.
--
-- The output starts with crate-level lint attributes and a comment block
-- telling the user where to save the file and which Cargo dependencies it
-- needs. That is followed by an empty line and the module body produced by
-- 'generateRustModuleContent'.
generateRustCode :: P.ParsedProject -> String
generateRustCode project =
  unlines (fileHeader ++ ["", generateRustModuleContent project])
  where
    -- Lines emitted verbatim at the very top of the generated file.
    fileHeader :: [String]
    fileHeader =
      [ "#![allow(dead_code)]",
        "#![allow(non_snake_case)]",
        "// Generated by the Palaba code generator. DO NOT EDIT.",
        "// Save this file as, for example, `src/convex_api.rs`",
        "// and then add `pub mod convex_api;` to your `src/lib.rs` or `src/main.rs`.",
        "//",
        "// Make sure your `Cargo.toml` contains the following dependencies:",
        "// convex = \"0.1.3\"",
        "// serde = { version = \"1.0\", features = [\"derive\"] }",
        "// serde_json = \"1.0\"",
        "// thiserror = \"1.0\"",
        "// anyhow = \"1.0\"",
        "// futures-util = \"0.3\""
      ]

-- | Generates the entire content for a single Rust module file.
--
-- Layout of the emitted module:
--
-- 1. The @use@ prelude.
-- 2. The fixed boilerplate sections: error enum, 'Id' wrapper,
--    'FromConvexValue' trait and impls, and the typed subscription stream.
-- 3. The generated @Api@ struct with its nested accessor structs.
-- 4. The @types@ submodule.
--
-- Each section is passed through 'stripNewlines', and an empty line is
-- inserted between sections. Nested type definitions produced while
-- generating the API functions are de-duplicated with 'nub' before being
-- handed to 'generateTypesModule'.
generateRustModuleContent :: P.ParsedProject -> String
generateRustModuleContent project =
  let (apiClassCode, nestedFromFuncs) = generateApiClass (P.ppFunctions project)
   in unlines
        [ "use convex::{ConvexClient, FunctionResult, Value};",
          "use futures_util::stream::Stream;",
          "use serde::{Deserialize, Deserializer, Serialize, Serializer};",
          "use serde_json;",
          "use std::collections::BTreeMap;",
          "use std::fmt::{self, Display};",
          "use std::marker::PhantomData;",
          "use std::pin::Pin;",
          "use std::task::{Context, Poll};",
          "",
          stripNewlines generateErrorEnum,
          "",
          stripNewlines generateIdStruct,
          "",
          stripNewlines generateFromConvexValueBoilerplate,
          "",
          stripNewlines generateSubscriptionBoilerplate,
          "",
          stripNewlines apiClassCode, -- API class and all submodules
          "",
          stripNewlines $ generateTypesModule project (nub nestedFromFuncs)
        ]

-- | Generates a Rust error enum using `thiserror`.
--
-- Besides the @ApiError@ enum itself, this emits an
-- @impl From<anyhow::Error> for ApiError@. That impl lets generated code use
-- @?@ to turn @anyhow::Error@ values into @ApiError::ConvexClientError@.
-- NOTE(review): the generated API methods apply @?@ to client calls, which
-- presumably return @anyhow@ results; confirm against the convex crate.
--
-- The @<anyhow::Error>@ argument on @From@ is mandatory. A bare
-- @impl From for ApiError@ is rejected by rustc, because @From@ is a generic
-- trait.
generateErrorEnum :: String
generateErrorEnum =
  unlines
    [ "/// Represents all possible errors that can occur when interacting with the API.",
      "#[derive(thiserror::Error, Debug)]",
      "pub enum ApiError {",
      indent 1 "#[error(\"Convex client error: {0}\")]",
      indent 1 "ConvexClientError(String),",
      "",
      indent 1 "#[error(\"Convex function error: {0}\")]",
      indent 1 "ConvexFunctionError(String),",
      "",
      indent 1 "#[error(\"Failed to deserialize response: {0}\")]",
      indent 1 "DeserializationError(#[from] serde_json::Error),",
      "",
      indent 1 "#[error(\"Unexpected null value returned from a non-nullable function\")]",
      indent 1 "UnexpectedNullError,",
      "}",
      "",
      "impl From<anyhow::Error> for ApiError {",
      indent 1 "fn from(err: anyhow::Error) -> Self {",
      indent 2 "ApiError::ConvexClientError(err.to_string())",
      indent 1 "}",
      "}"
    ]

-- | Generates the strongly-typed `Id` struct.
--
-- @Id<T>@ wraps a document-ID string and is tagged with a phantom type
-- parameter @T@ (held in a 'PhantomData' field). Generated table structs
-- declare @pub _id: Id<SomeDoc>@, so every impl emitted here must carry the
-- @<T>@ parameter explicitly. Without it, the struct declaration and each
-- impl are rejected by rustc.
--
-- @T@ is only a compile-time tag. In both the serde representation and the
-- 'convex::Value' representation, an 'Id' is a plain string.
generateIdStruct :: String
generateIdStruct =
  unlines
    [ "/// A strongly-typed Convex document ID.",
      "#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]",
      "pub struct Id<T> {",
      indent 1 "id: String,",
      indent 1 "_phantom: PhantomData<T>,",
      "}",
      "",
      "impl<T> Default for Id<T> {",
      indent 1 "fn default() -> Self {",
      indent 2 "Self { id: String::new(), _phantom: PhantomData }",
      indent 1 "}",
      "}",
      "",
      "impl<T> Id<T> {",
      indent 1 "pub fn new(id: String) -> Self {",
      indent 2 "Self { id, _phantom: PhantomData }",
      indent 1 "}",
      "}",
      "",
      "impl<T> Display for Id<T> {",
      indent 1 "fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {",
      indent 2 "write!(f, \"{}\", self.id)",
      indent 1 "}",
      "}",
      "",
      "impl<T> Serialize for Id<T> {",
      indent 1 "fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>",
      indent 1 "where",
      indent 2 "S: Serializer,",
      indent 1 "{",
      indent 2 "serializer.serialize_str(&self.id)",
      indent 1 "}",
      "}",
      "",
      "impl<'de, T> Deserialize<'de> for Id<T> {",
      indent 1 "fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>",
      indent 1 "where",
      indent 2 "D: Deserializer<'de>,",
      indent 1 "{",
      indent 2 "let id = String::deserialize(deserializer)?;",
      indent 2 "Ok(Id::new(id))",
      indent 1 "}",
      "}",
      "",
      "impl<T> From<Id<T>> for Value {",
      indent 1 "fn from(val: Id<T>) -> Self {",
      indent 2 "Value::String(val.id)",
      indent 1 "}",
      "}",
      "",
      "impl<T> TryFrom<Value> for Id<T> {",
      indent 1 "type Error = ApiError;",
      "",
      indent 1 "fn try_from(value: Value) -> Result<Self, Self::Error> {",
      indent 2 "if let Value::String(id) = value {",
      indent 3 "Ok(Id::new(id))",
      indent 2 "} else {",
      indent 3 "Err(ApiError::ConvexClientError(",
      indent 4 "\"Expected a string for Id\".to_string(),",
      indent 3 "))",
      indent 2 "}",
      indent 1 "}",
      "}",
      "",
      -- The TryFrom impl above already yields ApiError, so no error mapping
      -- is needed here.
      "impl<T> FromConvexValue for Id<T> {",
      indent 1 "fn from_convex(value: Value) -> Result<Self, ApiError> {",
      indent 2 "Self::try_from(value)",
      indent 1 "}",
      "}"
    ]

-- | Generates the generic TypedSubscription struct and its Stream implementation.
generateSubscriptionBoilerplate :: String generateSubscriptionBoilerplate = unlines [ "/// A type-safe, auto-deserializing stream of updates from a Convex query subscription.", "#[derive(Debug)]", "pub struct TypedSubscription {", indent 1 "raw_subscription: convex::QuerySubscription,", indent 1 "_phantom: PhantomData,", "}", "", "impl Stream for TypedSubscription", "where", indent 1 "T: FromConvexValue,", "{", indent 1 "type Item = Result;", "", indent 1 "fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {", indent 2 "let raw_sub_pin: Pin<&mut _> = unsafe { self.map_unchecked_mut(|s| &mut s.raw_subscription) };", indent 2 "match raw_sub_pin.poll_next(cx) {", indent 3 "Poll::Ready(Some(result)) => {", indent 4 "let item = match result {", indent 5 "FunctionResult::Value(value) => T::from_convex(value),", indent 5 "FunctionResult::ErrorMessage(s) => Err(ApiError::ConvexFunctionError(s)),", indent 5 "FunctionResult::ConvexError(err) => Err(ApiError::ConvexClientError(err.to_string())),", indent 4 "};", indent 4 "Poll::Ready(Some(item))", indent 3 "}", indent 3 "Poll::Ready(None) => Poll::Ready(None),", indent 3 "Poll::Pending => Poll::Pending,", indent 2 "}", indent 1 "}", "}" ] generateFromConvexValueBoilerplate :: String generateFromConvexValueBoilerplate = unlines [ "pub trait FromConvexValue: Sized {", " fn from_convex(value: Value) -> Result;", "}", "", "impl FromConvexValue for bool {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Boolean(b) => Ok(b),", " _ => Err(ApiError::ConvexClientError(\"Expected bool\".into())),", " }", " }", "}", "", "impl FromConvexValue for i64 {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Int64(i) => Ok(i),", " _ => Err(ApiError::ConvexClientError(\"Expected i64\".into())),", " }", " }", "}", "", "impl FromConvexValue for i32 {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Int64(i) => i", " .try_into()", " .map_err(|_| 
ApiError::ConvexClientError(\"i64 out of range for i32\".into())),", " _ => Err(ApiError::ConvexClientError(\"Expected i64\".into())),", " }", " }", "}", "", "impl FromConvexValue for f64 {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Float64(f) => Ok(f),", " _ => Err(ApiError::ConvexClientError(\"Expected f64\".into())),", " }", " }", "}", "", "impl FromConvexValue for f32 {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Float64(f) => Ok(f as f32),", " _ => Err(ApiError::ConvexClientError(\"Expected f64\".into())),", " }", " }", "}", "", "impl FromConvexValue for Vec", "where", " T: FromConvexValue,", "{", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Array(arr) => arr", " .into_iter()", " .map(T::from_convex)", " .collect::, ApiError>>(),", " _ => Err(ApiError::ConvexClientError(\"Expected array\".into())),", " }", " }", "}", "", "impl FromConvexValue for Option", "where", " T: FromConvexValue,", "{", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Null => Ok(None),", " other => T::from_convex(other).map(Some),", " }", " }", "}", "", "impl FromConvexValue for Vec {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::Bytes(bytes) => Ok(bytes),", " _ => Err(ApiError::ConvexClientError(\"Expected bytes\".into())),", " }", " }", "}", "", "impl FromConvexValue for String {", " fn from_convex(value: Value) -> Result {", " match value {", " Value::String(s) => Ok(s),", " _ => Err(ApiError::ConvexClientError(\"Expected string\".into())),", " }", " }", "}", "", "impl FromConvexValue for serde_json::Value {", " fn from_convex(value: Value) -> Result {", " Ok(value.into())", " }", "}", "" ] generateApiClass :: [Action.ConvexFunction] -> (String, [String]) generateApiClass funcs = let tree = buildPathTree funcs (structDefs, implDef, nested) = generateApiStructure "Api" tree in ( unlines [ "pub struct Api {", indent 1 "pub client: ConvexClient,", 
"}", structDefs, implDef ], nested ) generateApiStructure :: String -> PathTree -> (String, String, [String]) generateApiStructure parentName (DirNode dirMap) = let (structs, impls, nested) = unzip3 $ map (uncurry processEntry) (Map.toList dirMap) (accessors, functions) = partitionEntries (Map.toList dirMap) in ( unlines structs, unlines [ "impl" ++ (if parentName == "Api" then "" else "<'a>") ++ " " ++ parentName ++ (if parentName == "Api" then "" else "<'a>") ++ " {", if parentName == "Api" then indent 1 "pub fn new(client: ConvexClient) -> Self {\n Self { client }\n }" else "", unlines (map generateAccessorMethod accessors), unlines (concatMap generateMethodsForEntry functions), "}" ] ++ unlines impls, concat nested ) where processEntry name (DirNode subDir) = let structName = toPascalCase name (subStructs, subImpls, nestedFromSub) = generateApiStructure structName (DirNode subDir) structDef = unlines [ "pub struct " ++ structName ++ "<'a> {", indent 1 "client: &'a mut ConvexClient,", "}" ] in (unlines [structDef, subStructs], subImpls, nestedFromSub) processEntry _ (FuncNode func) = let (_, nestedFromFunc) = generateFunction func (_, nestedFromSub) = if Action.funcType func == Action.Query then generateSubscriptionFunction func else ("", []) in ("", "", nestedFromFunc ++ nestedFromSub) partitionEntries = foldl ( \(ds, fs) (name, node) -> case node of DirNode _ -> ((name, node) : ds, fs) FuncNode _ -> (ds, (name, node) : fs) ) ([], []) generateAccessorMethod (name, _) = let structName = toPascalCase name methodName = toSnakeCase name in unlines [ indent 1 ("pub fn " ++ methodName ++ "(&mut self) -> " ++ structName ++ "<'_> {"), indent 2 (structName ++ " { client: &mut self.client }"), indent 1 "}" ] generateMethodsForEntry (_, FuncNode func) = let (queryDef, _) = generateFunction func (subDef, _) = if Action.funcType func == Action.Query then generateSubscriptionFunction func else ("", []) in [queryDef, subDef] generateMethodsForEntry _ = [] generateApiStructure 
_ _ = ("", "", []) generateFunction :: Action.ConvexFunction -> (String, [String]) generateFunction func = let funcName = Action.funcName func args = Action.funcArgs func fullFuncPath = Action.funcPath func ++ ":" ++ funcName (argSignature, nestedFromArgs) = generateArgSignatureStruct fullFuncPath args funcNameSnake = toSnakeCase funcName (returnHint, isNullable, nestedFromReturn) = getReturnType funcName (Action.funcReturn func) handlerCall = case Action.funcType func of Action.Query -> "query" Action.Mutation -> "mutation" Action.Action -> "action" btreemapConstruction = generateBTreeMap (Action.funcArgs func) returnHandling = generateReturnHandling returnHint isNullable funcCode = case args of [] -> unlines [ indent 1 ("/// Wraps the `" ++ fullFuncPath ++ "` " ++ show (Action.funcType func) ++ "."), indent 1 ("pub async fn " ++ funcNameSnake ++ "(&mut self) -> Result<" ++ returnHint ++ ", ApiError> {"), btreemapConstruction, indent 2 ("let result = self.client." ++ handlerCall ++ "(\"" ++ fullFuncPath ++ "\", btmap).await?;"), returnHandling, indent 1 "}" ] _ -> unlines [ indent 1 ("/// Wraps the `" ++ fullFuncPath ++ "` " ++ show (Action.funcType func) ++ "."), indent 1 ("pub async fn " ++ funcNameSnake ++ "(&mut self, arg: " ++ argSignature ++ ") -> Result<" ++ returnHint ++ ", ApiError> {"), btreemapConstruction, indent 2 ("let result = self.client." 
++ handlerCall ++ "(\"" ++ fullFuncPath ++ "\", btmap).await?;"), returnHandling, indent 1 "}" ] in (funcCode, nestedFromArgs ++ nestedFromReturn) generateSubscriptionFunction :: Action.ConvexFunction -> (String, [String]) generateSubscriptionFunction func = let funcName = Action.funcName func args = Action.funcArgs func (argSignature, nestedFromArgs) = generateArgSignatureStruct fullFuncPath args funcNameSnake = "subscribe_" ++ toSnakeCase funcName (returnHint, _, nestedFromReturn) = getReturnType funcName (Action.funcReturn func) fullFuncPath = Action.funcPath func ++ ":" ++ funcName btreemapConstruction = generateBTreeMap (Action.funcArgs func) funcCode = case args of [] -> unlines [ indent 1 ("/// Subscribes to the `" ++ fullFuncPath ++ "` query."), indent 1 ("pub async fn " ++ funcNameSnake ++ "(&mut self) -> Result, ApiError> {"), btreemapConstruction, indent 2 ("let raw_subscription = self.client.subscribe(\"" ++ fullFuncPath ++ "\", btmap).await?;"), indent 2 "Ok(TypedSubscription {", indent 3 "raw_subscription,", indent 3 "_phantom: PhantomData,", indent 2 "})", indent 1 "}" ] _ -> unlines [ indent 1 ("/// Subscribes to the `" ++ fullFuncPath ++ "` query."), indent 1 ("pub async fn " ++ funcNameSnake ++ "(&mut self, arg: " ++ argSignature ++ ") -> Result, ApiError> {"), btreemapConstruction, indent 2 ("let raw_subscription = self.client.subscribe(\"" ++ fullFuncPath ++ "\", btmap).await?;"), indent 2 "Ok(TypedSubscription {", indent 3 "raw_subscription,", indent 3 "_phantom: PhantomData,", indent 2 "})", indent 1 "}" ] in (funcCode, nestedFromArgs ++ nestedFromReturn) generateTypesModule :: P.ParsedProject -> [String] -> String generateTypesModule project nestedFromFuncs = let (tableCode, nestedFromTables) = generateAllTables (P.ppSchema project) (constantsCode, nestedFromConstants) = generateAllConstants (P.ppConstants project) allNestedCode = nub (nestedFromTables ++ nestedFromConstants ++ nestedFromFuncs) in unlines [ "pub mod types {", indent 1 "use 
super::*;", "", tableCode, constantsCode, unlines allNestedCode, "}" ] generateAllTables :: Schema.Schema -> (String, [String]) generateAllTables (Schema.Schema tables) = let (tableCodes, nested) = unzip $ map generateTableStruct tables in (unlines tableCodes, concat nested) generateTableStruct :: Schema.Table -> (String, [String]) generateTableStruct table = let className = toPascalCase (Schema.tableName table) ++ "Doc" (fieldLines, nestedFromFields) = unzip $ map (generateField className) (Schema.tableFields table) allFields = [ ("_id", Schema.VId (toPascalCase (Schema.tableName table))), ("_creation_time", Schema.VFloat64) ] ++ map (\f -> (Schema.fieldName f, Schema.fieldType f)) (Schema.tableFields table) fromBlock = generateFromConvexValueImpl (className) allFields toBlock = generateToConvexValueImpl (className) allFields in ( unlines [ "#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]", ("pub struct " ++ className ++ " {"), indent 1 "#[serde(default)]", indent 1 ("pub _id: Id<" ++ className ++ ">,"), indent 1 "#[serde(default)]", indent 1 "#[serde(rename = \"_creationTime\")]", indent 1 "pub _creation_time: f64,", unlines fieldLines, "}", "", fromBlock, "", toBlock ], concat nestedFromFields ) generateAllConstants :: Map.Map String Schema.ConvexType -> (String, [String]) generateAllConstants constants = let (constCodes, nested) = unzip $ map (uncurry generateConstant) (Map.toList constants) in (unlines constCodes, concat nested) generateConstant :: String -> Schema.ConvexType -> (String, [String]) generateConstant name u@(Schema.VUnion literals) | all Schema.isLiteral literals = let enumName = toPascalCase name enumFromConvexValueImpl = generateFromConvexValueImplEnum ("types::" ++ enumName) literals variantNames = map (\l -> let n = Schema.getLiteralString l in (Schema.sanitizeUnionValues n, n)) literals buildVariantLines [] = [] buildVariantLines ((sanitizedFirst, originalFirst) : rest) = (indent 2 "#[default]\n" ++ indent 2 
("#[serde(rename = \"" ++ originalFirst ++ "\")]\n") ++ indent 2 ((toPascalCase sanitizedFirst) ++ ",")) : map (\(sanitizedV, originalV) -> indent 2 ("#[serde(rename = \"" ++ originalV ++ "\")]\n") ++ indent 2 (toPascalCase sanitizedV ++ ",")) rest code = unlines [ indent 1 "#[derive(Default, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]", indent 1 ("pub enum " ++ enumName ++ " {"), unlines $ buildVariantLines variantNames, indent 1 "}", "", enumFromConvexValueImpl, "" ] in (code, []) | otherwise = let (rustTypeName, nested) = toRustType name u in (indent 1 ("pub type " ++ toPascalCase name ++ " = " ++ rustTypeName ++ ";"), nested) generateConstant name t = let (rustTypeName, nested) = toRustType name t in (indent 1 ("pub type " ++ toPascalCase name ++ " = " ++ rustTypeName ++ ";"), nested) generateField :: String -> Schema.Field -> (String, [String]) generateField nameHint field = let fieldNameSnake = toSnakeCase (Schema.fieldName field) (rustType, nested) = toRustType (nameHint ++ capitalize (Schema.fieldName field)) (Schema.fieldType field) serdeRename = if fieldNameSnake /= Schema.fieldName field then indent 2 ("#[serde(rename = \"" ++ Schema.fieldName field ++ "\")]\n") else "" serdeAttrs = let defaultAttr = "default" skipAttr = if needsOptionalWrapper (Schema.fieldType field) then Just "skip_serializing_if = \"Option::is_none\"" else Nothing allAttrs = case skipAttr of Just s -> [defaultAttr, s] Nothing -> [defaultAttr] in indent 2 ("#[serde(" ++ intercalate ", " allAttrs ++ ")]") fieldLine = serdeRename ++ serdeAttrs ++ "\n" ++ indent 2 ("pub " ++ fieldNameSnake ++ ": " ++ rustType ++ ",") in (fieldLine, nested) generateArgSignatureStruct :: String -> [(String, Schema.ConvexType)] -> (String, [String]) generateArgSignatureStruct _ [] = ("", []) generateArgSignatureStruct fullFuncName args = let argStructName = toPascalCase $ map (\c -> if (c == '/' || c == ':') then '_' else c) $ fullFuncName ++ "Arg" (argCode, subArgs) = generateConstant 
argStructName $ Schema.VObject args in ("types::" ++ argStructName ++ "Object", argCode : subArgs) generateBTreeMap :: [(String, Schema.ConvexType)] -> String generateBTreeMap [] = indent 2 "let btmap = BTreeMap::new();" generateBTreeMap btmap = let buildStmts (name, convexType) = let varName = "arg." ++ toSnakeCase name in case convexType of Schema.VObject _ -> indent 1 ("btmap.insert(\"" ++ name ++ "\".to_string(), " ++ varName ++ ".to_convex_value()?);") Schema.VOptional innerConvexType -> indent 1 ("if let Some(v) = " ++ varName ++ " { btmap.insert(\"" ++ name ++ "\".to_string(), " ++ fieldToConvexValue ("v", innerConvexType) ++ "); }") _ -> indent 1 ("btmap.insert(\"" ++ name ++ "\".to_string(), " ++ fieldToConvexValue (varName, convexType) ++ ");") in unlines [ indent 2 "let mut btmap = BTreeMap::new();", unlines $ map buildStmts btmap ] fieldToConvexValue :: (String, Schema.ConvexType) -> String fieldToConvexValue (fieldName, t) = let fieldNameSnake = toSnakeCase fieldName valueExpr = innerValueToConvexNonOptional fieldNameSnake t in valueExpr toClonedValue :: String -> Schema.ConvexType -> String toClonedValue varName (Schema.VString) = varName ++ ".to_string()" toClonedValue varName t | isPassedByCopy t = "*" ++ varName | otherwise = varName ++ ".clone()" isPassedByCopy :: Schema.ConvexType -> Bool isPassedByCopy Schema.VNumber = True isPassedByCopy Schema.VInt64 = True isPassedByCopy Schema.VFloat64 = True isPassedByCopy Schema.VBoolean = True isPassedByCopy _ = False getReturnType :: String -> Schema.ConvexType -> (String, Bool, [String]) getReturnType funcName rt = let (baseType, nested) = toRustType (funcName ++ "Return") rt isNullable = needsOptionalWrapper rt in if baseType == "()" then ("()", False, nested) else (baseType, isNullable, nested) generateReturnHandling :: String -> Bool -> String generateReturnHandling "()" _ = unlines [ indent 2 "match result {", indent 3 "FunctionResult::Value(_) => Ok(()),", indent 3 "FunctionResult::ErrorMessage(s) 
=> Err(ApiError::ConvexFunctionError(s)),", indent 3 "FunctionResult::ConvexError(err) => Err(ApiError::ConvexClientError(err.to_string())),", indent 2 "}" ] generateReturnHandling _ isNullable = if isNullable then unlines [ indent 2 "match result {", indent 3 "FunctionResult::Value(val) => Ok(FromConvexValue::from_convex(val.clone())?),", indent 3 "FunctionResult::ErrorMessage(s) => Err(ApiError::ConvexFunctionError(s)),", indent 3 "FunctionResult::ConvexError(err) => Err(ApiError::ConvexClientError(err.to_string())),", indent 2 "}" ] else unlines [ indent 2 "match result {", indent 3 "FunctionResult::Value(Value::Null) => Err(ApiError::UnexpectedNullError),", indent 3 "FunctionResult::Value(val) => Ok(FromConvexValue::from_convex(val.clone())?),", indent 3 "FunctionResult::ErrorMessage(s) => Err(ApiError::ConvexFunctionError(s)),", indent 3 "FunctionResult::ConvexError(err) => Err(ApiError::ConvexClientError(err.to_string())),", indent 2 "}" ] generateToConvexValueImpl :: String -> [(String, Schema.ConvexType)] -> String generateToConvexValueImpl structName fields = let buildMapInserts (fieldName, fieldType) = let conversionBlock = generateFieldToConvexValue (fieldName, fieldType) in indent 3 conversionBlock mapInserts = unlines $ map buildMapInserts fields in unlines [ "impl " ++ structName ++ " {", indent 1 "pub fn to_convex_value(&self) -> Result {", indent 2 "let mut btmap = BTreeMap::new();", mapInserts, indent 2 "Ok(Value::Object(btmap))", indent 1 "}", "}" ] generateFromConvexValueImplEnum :: String -> [Schema.ConvexType] -> String generateFromConvexValueImplEnum structName fields = unlines [ "impl TryFrom for " ++ structName ++ " {", indent 1 "type Error = ApiError;", indent 1 "fn try_from(value: Value) -> Result {", indent 2 "if let Value::String(s) = &value {", indent 3 "return match s.as_str() {", unlines . 
map (indent 4) $ generateEnumMatchCases structName fields, indent 2 "}", indent 1 "}", indent 2 "Err(ApiError::ConvexClientError(\"Expected a string for " ++ structName ++ "\".to_string()))", indent 1 "}", indent 1 "}", "", indent 1 $ "impl FromConvexValue for " ++ structName ++ " {", indent 2 $ "fn from_convex(value: Value) -> Result {", indent 3 $ structName ++ "::try_from(value)", indent 2 $ "}", indent 1 $ "}", "" ] generateEnumMatchCases :: String -> [Schema.ConvexType] -> [String] generateEnumMatchCases structName fields = let cases = map ( \case (Schema.VLiteral s) -> "\"" ++ s ++ "\" => Ok(" ++ structName ++ "::" ++ (toPascalCase . Schema.sanitizeUnionValues $ s) ++ ")," _ -> error "Expected a literal for enum field" ) fields in cases ++ ["x => Err(ApiError::ConvexClientError(\"Expected one of " ++ enumValues ++ " for " ++ structName ++ " but got \".to_string() + &x.to_string()))"] where enumValues = "[" ++ (intercalate ", " $ map Schema.getLiteralString fields) ++ "]" -- #[derive(Default, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] -- pub enum AccessLevelEnum { -- #[default] -- #[serde(rename = "read")] -- Read, -- #[serde(rename = "edit")] -- Edit, -- } -- -- impl From for AccessLevelEnum { -- fn from(value: Value) -> Self { -- match value { -- Value::String(s) => match s.as_str() { -- "read" => AccessLevelEnum::Read, -- "edit" => AccessLevelEnum::Edit, -- _ => panic!("Unknown AccessLevelEnum value"), -- }, -- _ => panic!("Expected a string for AccessLevelEnum"), -- } -- } -- } generateFromConvexValueImpl :: String -> [(String, Schema.ConvexType)] -> String generateFromConvexValueImpl structName fields = unlines [ "impl TryFrom for " ++ structName ++ " {", indent 1 "type Error = ApiError;", indent 1 "fn try_from(value: Value) -> Result {", indent 2 "let obj = match value {", indent 3 "Value::Object(map) => map,", indent 3 "_ => return Err(ApiError::ConvexClientError(\"Expected object\".to_string())),", indent 2 "};", unlines $ map (indent 
2 . generateAccessorFnFromConvexValue structName) fields, unlines $ map (indent 2) $ generateFromConvexValueResult structName fields, indent 1 "}", "}", "", indent 1 $ "impl FromConvexValue for " ++ structName ++ " {", indent 2 $ "fn from_convex(value: Value) -> Result {", indent 3 $ structName ++ "::try_from(value)", indent 2 $ "}", indent 1 $ "}", "" ] generateFromConvexValueResult :: String -> [(String, Schema.ConvexType)] -> [String] generateFromConvexValueResult structName fields = let fieldNames = map (toSnakeCase . fst) fields fieldAccessors = map (\name -> "get_" ++ name) fieldNames setterStatements = zipWith ( \name getter -> case name of "_creation_time" -> indent 3 $ "_creation_time: " ++ getter ++ "(&obj, \"_creationTime\")?" n -> indent 3 $ n ++ ": " ++ getter ++ "(&obj, \"" ++ n ++ "\")?" ) fieldNames fieldAccessors in [ indent 2 ("Ok(" ++ structName ++ " {"), indent 3 $ intercalate ",\n" $ setterStatements, indent 2 "})" ] -- Ok(types::GetAssetsReturnObject { -- _id: Id::new(get__id(&obj, "_id")?), -- _creation_time: get__creation_time(&obj, "_creationTime")?, -- project_id: Id::new(get_project_id(&obj, "project_id")?), -- asset_name: get_asset_name(&obj, "asset_name")?, -- asset_essence_mtime: get_asset_essence_mtime(&obj, "asset_essence_mtime")?, -- link_metadata: get_link_metadata(&obj, "link_metadata")?, -- }) generateAccessorFnFromConvexValue :: String -> (String, Schema.ConvexType) -> String generateAccessorFnFromConvexValue structName (fieldName, Schema.VOptional fieldType) = let fieldNameSnake = toSnakeCase fieldName getterName = "get_" ++ fieldNameSnake (fieldTypeStr, _) = case stripPrefix (reverse "Object") (reverse structName) of Just cleanStructName -> toRustType (reverse cleanStructName ++ capitalize fieldName) fieldType Nothing -> toRustType (structName ++ capitalize fieldName) fieldType in unlines [ indent 0 $ "fn " ++ getterName ++ "(map: &BTreeMap, key: &str) -> Result, ApiError> {", indent 1 $ "match map.get(key) {", indent 2 $ 
"Some(v) => {", indent 3 $ "Ok(Some(FromConvexValue::from_convex(v.clone())?))", indent 2 $ "}", indent 2 $ "_ => Ok(None),", indent 1 $ "}", indent 0 $ "}", "" ] generateAccessorFnFromConvexValue structName (fieldName, fieldType) = let fieldNameSnake = toSnakeCase fieldName getterName = "get_" ++ fieldNameSnake (fieldTypeStr, _) = case stripPrefix (reverse "Object") (reverse structName) of Just cleanStructName -> toRustType (reverse cleanStructName ++ capitalize fieldName) fieldType Nothing -> toRustType (structName ++ capitalize fieldName) fieldType in unlines [ indent 0 $ "fn " ++ getterName ++ "(map: &BTreeMap, key: &str) -> Result<" ++ fieldTypeStr ++ ", ApiError> {", indent 1 $ "match map.get(key) {", indent 2 $ "Some(v) => {", indent 3 $ "Ok(FromConvexValue::from_convex(v.clone())?)", indent 2 $ "}", indent 2 $ "_ => return Err(ApiError::ConvexClientError(format!(\"Expected field (" ++ fieldTypeStr ++ ") '{}' not found\", key))),", indent 1 $ "}", indent 0 $ "}", "" ] -- An example implementation: -- -- pub struct GetAssetsReturnObject { -- #[serde(default)] -- pub _id: Id, -- #[serde(rename = "_creationTime")] -- #[serde(default)] -- pub _creation_time: f64, -- #[serde(default)] -- pub project_id: Id, -- #[serde(default)] -- pub asset_name: String, -- #[serde(default)] -- pub link_metadata: types::GetAssetsReturnLinkMetadataObject, -- #[serde(default)] -- pub asset_essence_mtime: i64, -- } -- -- impl TryFrom for types::GetAssetsReturnObject { -- type Error = ApiError; -- -- fn try_from(value: Value) -> Result { -- let obj = match value { -- Value::Object(map) => map, -- _ => return Err(ApiError::ConvexClientError("Expected object".to_string())), -- }; -- -- fn get__id(map: &BTreeMap, key: &str) -> Result { -- match map.get(key) { -- Some(Value::String(s)) => Ok(s.clone()), -- _ => Err(ApiError::ConvexClientError(format!( -- "Expected string for field '{}'", -- key -- ))), -- } -- } -- -- fn get_project_id(map: &BTreeMap, key: &str) -> Result { -- match 
map.get(key) { -- Some(Value::String(s)) => Ok(s.clone()), -- _ => Err(ApiError::ConvexClientError(format!( -- "Expected string for field '{}'", -- key -- ))), -- } -- } -- -- fn get__creation_time(map: &BTreeMap, key: &str) -> Result { -- match map.get(key) { -- Some(Value::Float64(f)) => Ok(*f), -- _ => Err(ApiError::ConvexClientError(format!( -- "Expected float64 for field '{}'", -- key -- ))), -- } -- } -- -- fn get_asset_essence_mtime( -- map: &BTreeMap, -- key: &str, -- ) -> Result { -- match map.get(key) { -- Some(Value::Int64(i)) => Ok(*i), -- _ => Err(ApiError::ConvexClientError(format!( -- "Expected int64 for field '{}'", -- key -- ))), -- } -- } -- -- fn get_asset_name(map: &BTreeMap, key: &str) -> Result { -- match map.get(key) { -- Some(Value::String(s)) => Ok(s.clone()), -- _ => Err(ApiError::ConvexClientError(format!( -- "Expected string for field '{}'", -- key -- ))), -- } -- } -- -- fn get_link_metadata( -- map: &BTreeMap, -- key: &str, -- ) -> Result { -- match map.get(key) { -- Some(Value::Object(inner)) => { -- let length = match inner.get("length") { -- Some(Value::Int64(i)) => *i, -- _ => { -- return Err(ApiError::ConvexClientError( -- "Expected int64 for 'length'".into(), -- )); -- } -- }; -- let sample_rate = match inner.get("sample_rate") { -- Some(Value::Int64(i)) => *i, -- _ => { -- return Err(ApiError::ConvexClientError( -- "Expected int64 for 'sample_rate'".into(), -- )); -- } -- }; -- let summary = match inner.get("summary") { -- Some(Value::Bytes(b)) => b.clone(), -- _ => { -- return Err(ApiError::ConvexClientError( -- "Expected bytes for 'summary'".into(), -- )); -- } -- }; -- Ok(types::GetAssetsReturnLinkMetadataObject { -- length, -- sample_rate, -- summary, -- }) -- } -- _ => Err(ApiError::ConvexClientError( -- "Expected object for 'link_metadata'".into(), -- )), -- } -- } -- -- Ok(types::GetAssetsReturnObject { -- _id: Id::new(get__id(&obj, "_id")?), -- _creation_time: get__creation_time(&obj, "_creationTime")?, -- project_id: 
-- Id::new(get_project_id(&obj, "project_id")?),
-- asset_name: get_asset_name(&obj, "asset_name")?,
-- asset_essence_mtime: get_asset_essence_mtime(&obj, "asset_essence_mtime")?,
-- link_metadata: get_link_metadata(&obj, "link_metadata")?,
-- })
-- }
-- }
-- Some(Value::Array(v)) => Ok(Some(
-- v.iter()
-- .map(|item| item.clone().try_into())
-- .collect::, ApiError>>()?,
-- )),

-- | How a generated Rust conversion expression handles a fallible step.
--
-- * 'UnwrapWithTry': the expression is emitted directly in a generated method
--   body, so failures are propagated with Rust's @?@ operator and the
--   expression evaluates to a plain @Value@.
-- * 'YieldResult': the expression is the body of an element closure
--   (@|item| ...@) whose value is gathered by @collect::<Result<..>>@, so it
--   must evaluate to a @Result@ itself instead of using @?@.
data RustErrorMode = UnwrapWithTry | YieldResult

-- | Renders the Rust statement(s) that insert one struct field into @btmap@
-- (the map the generated @to_convex_value@ body builds; it is declared by the
-- caller, not here).
--
-- Optional fields are inserted only when they hold a value (@if let Some(v)@);
-- all other fields are inserted unconditionally. The map key is the original
-- Convex field name; the Rust struct field is accessed by its snake_case name.
generateFieldToConvexValue :: (String, Schema.ConvexType) -> String
generateFieldToConvexValue (fieldName, Schema.VOptional inner) =
  unlines
    [ "if let Some(v) = &self." ++ toSnakeCase fieldName ++ " {",
      -- `v` is a reference to the unwrapped inner value.
      indent 1 (btmapInsertStatement fieldName (innerValueToConvexOptional "v" inner)),
      "}"
    ]
generateFieldToConvexValue (fieldName, fieldType) =
  btmapInsertStatement
    fieldName
    (innerValueToConvexNonOptional ("self." ++ toSnakeCase fieldName) fieldType)

-- | The @btmap.insert("<key>".to_string(), <value>);@ statement for one field.
btmapInsertStatement :: String -> String -> String
btmapInsertStatement fieldName valueExpr =
  "btmap.insert(\"" ++ fieldName ++ "\".to_string(), " ++ valueExpr ++ ");"

-- | Converts the value bound by @if let Some(v)@ for an optional field (a
-- reference to the field's inner type) into a Rust @Value@ expression. The
-- expression sits in the method body, so fallible steps are unwrapped with @?@.
innerValueToConvexOptional :: String -> Schema.ConvexType -> String
innerValueToConvexOptional varName (Schema.VArray inner) =
  convexArrayExpr UnwrapWithTry varName inner
innerValueToConvexOptional varName (Schema.VObject _) =
  -- `to_convex_value` is fallible (see the non-optional case), so unwrap it
  -- here too rather than inserting the raw `Result` into the map.
  "Value::from(" ++ varName ++ ".to_convex_value()?)"
innerValueToConvexOptional varName t
  | isComplexForBTreeMap t = serdeFallbackExpr UnwrapWithTry (toClonedValue varName t)
  | otherwise = "Value::from(" ++ toClonedValue varName t ++ ")"

-- | Converts one array element (bound as @varName@ inside a @|item| ...@
-- closure) into a Rust expression. Fallible element conversions evaluate to
-- a @Result@ (never use @?@), so the enclosing collect can gather them;
-- 'isFallibleElement' reports which element types take that path.
innerValueToConvexArray :: String -> Schema.ConvexType -> String
innerValueToConvexArray varName (Schema.VArray inner) =
  convexArrayExpr YieldResult varName inner
innerValueToConvexArray varName (Schema.VObject _) =
  varName ++ ".to_convex_value()"
innerValueToConvexArray varName t
  | isComplexForBTreeMap t = serdeFallbackExpr YieldResult (toClonedValue varName t)
  | otherwise = "Value::from(" ++ toClonedValue varName t ++ ")"

-- | Converts a non-optional struct field (accessed as @varName@, e.g.
-- @self.foo@) into a Rust @Value@ expression, unwrapping fallible steps with @?@.
innerValueToConvexNonOptional :: String -> Schema.ConvexType -> String
innerValueToConvexNonOptional varName (Schema.VArray inner) =
  convexArrayExpr UnwrapWithTry varName inner
innerValueToConvexNonOptional varName Schema.VBytes =
  "Value::from(" ++ varName ++ ".to_vec())"
innerValueToConvexNonOptional varName (Schema.VObject _) =
  "Value::from(" ++ varName ++ ".to_convex_value()?)"
innerValueToConvexNonOptional varName t
  | isComplexForBTreeMap t = serdeFallbackExpr UnwrapWithTry (toClonedValue varName t)
  | otherwise = "Value::from(" ++ toClonedValueNonOptional varName t ++ ")"

-- | Builds a @Value::Array@ expression from the Rust collection @varName@
-- whose elements have schema type @elemType@.
--
-- Infallible elements are collected straight into the array. Fallible ones
-- are collected into @Result<Vec<_>, _>@, which is then either unwrapped
-- with @?@ ('UnwrapWithTry') or mapped to @Result<Value, _>@ ('YieldResult'),
-- the latter being required when this array is itself an element of an
-- outer array.
convexArrayExpr :: RustErrorMode -> String -> Schema.ConvexType -> String
convexArrayExpr UnwrapWithTry varName elemType
  | isFallibleElement elemType =
      "Value::Array(" ++ mappedArrayElements varName elemType ++ ".collect::<Result<Vec<_>, _>>()?)"
convexArrayExpr YieldResult varName elemType
  | isFallibleElement elemType =
      mappedArrayElements varName elemType ++ ".collect::<Result<Vec<_>, _>>().map(Value::Array)"
convexArrayExpr _ varName elemType =
  "Value::Array(" ++ mappedArrayElements varName elemType ++ ".collect())"

-- | @<varName>.iter().map(|item| <element conversion>)@.
mappedArrayElements :: String -> Schema.ConvexType -> String
mappedArrayElements varName elemType =
  varName ++ ".iter().map(|item| " ++ innerValueToConvexArray "item" elemType ++ ")"

-- | Whether 'innerValueToConvexArray' yields a @Result@ (rather than a plain
-- @Value@) for an element of this type. Mirrors that function's branches.
isFallibleElement :: Schema.ConvexType -> Bool
isFallibleElement (Schema.VArray inner) = isFallibleElement inner
isFallibleElement (Schema.VObject _) = True
isFallibleElement t = isComplexForBTreeMap t

-- | Best-effort conversion through @serde_json@ for types without a direct
-- @Value@ conversion. A serialization failure is replaced by the JSON string
-- "unable to serialize"; only the @Value::try_from@ step can fail.
serdeFallbackExpr :: RustErrorMode -> String -> String
serdeFallbackExpr mode valueExpr =
  "Value::try_from(serde_json::to_value("
    ++ valueExpr
    ++ ").unwrap_or(\"unable to serialize\".into()))"
    ++ tryOperatorFor mode

-- | The suffix that unwraps a fallible expression in the given mode.
tryOperatorFor :: RustErrorMode -> String
tryOperatorFor UnwrapWithTry = "?"
tryOperatorFor YieldResult = ""

-- | Produces an owned copy of a non-optional field for @Value::from@:
-- strings via @.to_string()@, copyable types as-is, anything else cloned.
toClonedValueNonOptional :: String -> Schema.ConvexType -> String
toClonedValueNonOptional varName Schema.VString = varName ++ ".to_string()"
toClonedValueNonOptional varName t
  | isPassedByCopy t = varName
  | otherwise = varName ++ ".clone()"

-- | Types that are converted to @Value@ by round-tripping through
-- @serde_json@ instead of a direct @Value::from@.
isComplexForBTreeMap :: Schema.ConvexType -> Bool
isComplexForBTreeMap (Schema.VUnion _) = True
isComplexForBTreeMap (Schema.VReference _) = True
isComplexForBTreeMap Schema.VAny = True
isComplexForBTreeMap _ = False

-- | Maps a Convex schema type to the Rust type used in generated code,
-- together with any Rust definitions (structs, enums) generated on the way.
-- @nameHint@ names generated types: objects become
-- @<PascalHint>Object@ and literal unions become @<PascalHint>@.
toRustType :: String -> Schema.ConvexType -> (String, [String])
toRustType nameHint typ = case typ of
  Schema.VString -> ("String", [])
  Schema.VNumber -> ("f64", [])
  Schema.VInt64 -> ("i64", [])
  Schema.VFloat64 -> ("f64", [])
  Schema.VBoolean -> ("bool", [])
  Schema.VAny -> ("serde_json::Value", [])
  Schema.VNull -> ("()", [])
  -- NOTE(review): the table name is ignored and no type parameter is emitted,
  -- although 'toRustBorrowType' treats ids as @Id<...>@ — confirm the
  -- intended marker type.
  Schema.VId _table -> ("Id", [])
  Schema.VBytes -> ("Vec<u8>", [])
  Schema.VArray inner ->
    let (innerType, nested) = toRustType nameHint inner
     in ("Vec<" ++ innerType ++ ">", nested)
  Schema.VOptional inner ->
    let (innerType, nested) = toRustType nameHint inner
     in ("Option<" ++ innerType ++ ">", nested)
  Schema.VObject fields ->
    let className = toPascalCase nameHint ++ "Object"
        (fieldLines, nestedFields) =
          unzip (map (generateField nameHint . uncurry Schema.Field) fields)
        newModel =
          unlines
            [ indent 1 "#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]",
              indent 1 ("pub struct " ++ className ++ " {"),
              unlines fieldLines,
              indent 1 "}",
              "",
              generateToConvexValueImpl className fields,
              "",
              generateFromConvexValueImpl className fields
            ]
     in ("types::" ++ className, concat nestedFields ++ [newModel])
  Schema.VUnion types ->
    -- NOTE(review): null members are dropped here; the literal-enum and
    -- fallback branches do not add an Option, so nullability presumably
    -- comes from callers (cf. 'needsOptionalWrapper') — confirm.
    let nonNullTypes = filter (/= Schema.VNull) types
     in case nonNullTypes of
          [] -> ("Option<()>", [])
          [singleType] ->
            let (innerType, nested) = toRustType nameHint singleType
             in ("Option<" ++ innerType ++ ">", nested)
          _
            | all Schema.isLiteral nonNullTypes -> literalUnionEnum nonNullTypes
            | otherwise -> ("serde_json::Value", []) -- Fallback for complex unions
  Schema.VLiteral _ -> (toPascalCase nameHint, [])
  Schema.VReference n -> ("types::" ++ toPascalCase n, [])
  Schema.VVoid -> ("()", [])
  where
    -- A union made only of literals becomes a unit-variant enum; the first
    -- literal is the #[default] variant and each variant keeps its original
    -- spelling on the wire via #[serde(rename)].
    literalUnionEnum literals =
      let enumName = toPascalCase nameHint
          variantLine v =
            indent 2 ("#[serde(rename = \"" ++ v ++ "\")]\n") ++ indent 2 (toPascalCase v ++ ",")
          variantLines = case map Schema.getLiteralString literals of
            [] -> []
            (firstVariant : others) ->
              (indent 2 "#[default]\n" ++ variantLine firstVariant) : map variantLine others
          newEnum =
            unlines
              [ indent 1 "#[derive(Default, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]",
                indent 1 ("pub enum " ++ enumName ++ " {"),
                unlines variantLines,
                indent 1 "}",
                "",
                generateFromConvexValueImplEnum ("types::" ++ enumName) literals,
                ""
              ]
       in ("types::" ++ enumName, [newEnum])

-- | The borrowed Rust type used for function parameters: @String@ becomes
-- @&str@, @Vec<T>@ becomes @&[T]@ (so @Vec<u8>@ becomes @&[u8]@), ids and
-- @serde_json::Value@ are taken by reference, and @Option<T>@ borrows its
-- payload. Any other type is passed unchanged.
toRustBorrowType :: String -> String
toRustBorrowType rustType
  | rustType == "String" = "&str"
  | rustType == "serde_json::Value" = "&serde_json::Value"
  | "Id<" `isPrefixOf` rustType = "&" ++ rustType
  | Just inner <- genericArgument "Vec" = "&[" ++ inner ++ "]"
  | Just inner <- genericArgument "Option" = "Option<" ++ toRustBorrowType inner ++ ">"
  | otherwise = rustType
  where
    -- `genericArgument "Vec"` turns "Vec<X>" into Just "X"; Nothing otherwise.
    genericArgument wrapper = do
      body <- stripPrefix (wrapper ++ "<") rustType
      reverse <$> stripPrefix ">" (reverse body)

-- | Whether a field of this type may be absent/null: explicit optionals and
-- unions that include @null@.
needsOptionalWrapper :: Schema.ConvexType -> Bool
needsOptionalWrapper (Schema.VOptional _) = True
needsOptionalWrapper (Schema.VUnion ts) = Schema.VNull `elem` ts
needsOptionalWrapper _ = False

-- | @snake_case@ (or space-separated) to @PascalCase@; the rest of each word
-- is left as-is.
toPascalCase :: String -> String
toPascalCase = concatMap capitalize . words . map underscoreToSpace
  where
    underscoreToSpace c = if c == '_' then ' ' else c

    capitalize :: String -> String
    capitalize "" = ""
    capitalize (c : cs) = toUpper c : cs

-- | @camelCase@ to @snake_case@: the first character is lowercased, and every
-- later uppercase character becomes @_@ plus its lowercase form.
toSnakeCase :: String -> String
toSnakeCase "" = ""
toSnakeCase (c : cs) = toLower c : concatMap snakeChar cs
  where
    snakeChar ch
      | isUpper ch = ['_', toLower ch]
      | otherwise = [ch]

-- | Drops empty lines (despite the name, line breaks between non-empty lines
-- are kept); the result ends with a newline, as produced by 'unlines'.
stripNewlines :: String -> String
stripNewlines = unlines . filter (not . null) . lines