diff --git a/README.md b/README.md index 1a8cbc2..d24565f 100644 --- a/README.md +++ b/README.md @@ -105,3 +105,34 @@ let ms = MyStruct::get_partial("someid", &serde_json::json!({"other": 1})).await let myref = ms.other.unwrap(); // will be there let name = ms.name.unwrap() // will panic! ``` + +### Updating values +You can either update the values by passing a JSON object overwriting the current values or update the values using a builder pattern. + +```rust +[...] +struct MyStruct { + _id: String, + name: String, + age: u32, + other: Option, +} + +let mut a = MyStruct::get("someid").await.unwrap(); + +// JSON +a.update(serde_json::json!({"name": "bye"})).await.unwrap(); + +// Builder +let mut changes = a.change(); + +// Set fields +changes = changes.name("bye"); + +// There are type specific functions +// Increment age by 1 +changes = changes.age_increment(1); + +// Finalize +let changed_model: MyStruct = changes.update().await.unwrap(); +``` diff --git a/mongod_derive/src/change_fields.rs b/mongod_derive/src/change_fields.rs new file mode 100644 index 0000000..96894c3 --- /dev/null +++ b/mongod_derive/src/change_fields.rs @@ -0,0 +1,213 @@ +use quote::quote; +use syn::Field; + +use crate::{extract_inner_type, is_one_of_type, is_type}; + +/// Generate the ChangeBuilder field fns +pub fn builder_change_fields(field: &Field) -> proc_macro2::TokenStream { + let field_name = &field.ident.as_ref().unwrap(); + let field_type = &field.ty; + let field_name_str = field_name.to_string(); + + // Never update _id + if field_name_str == "_id" { + return quote! {}; + } + + let number_types = [ + "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize", + "f16", "f32", "f64", "f128", + ]; + + // Number type fn + if is_one_of_type(field_type, &number_types) { + let documentation = format!("Set the value of `{field_name}`"); + let inc_fn_name = syn::Ident::new(&format!("{}_increment", field_name), field_name.span()); + let doc_inc = format!("Increment value of `{field_name}` by `value`. Consecutive calls to this function will not add up, they overwrite the increment."); + + let mul_fn_name = syn::Ident::new(&format!("{}_multiply", field_name), field_name.span()); + let doc_mul = format!("Multiply value of `{field_name}` by `value`. Consecutive calls to this function will not add up, they overwrite the multiply."); + + return quote! { + #[doc = #documentation] + pub fn #field_name(mut self, value: #field_type)-> Self { + self.model.#field_name = value.into(); + + self.changeset.entry("$set".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&self.model.#field_name).unwrap(), + ); + + self + } + + #[doc = #doc_inc] + pub fn #inc_fn_name(mut self, value: #field_type) -> Self { + self.model.#field_name += value; + + self.changeset.entry("$inc".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap() + .insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&value).unwrap(), + ); + + self + } + + #[doc = #doc_mul] + pub fn #mul_fn_name(mut self, value: #field_type) -> Self { + self.model.#field_name *= value; + + self.changeset.entry("$mul".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap() + .insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&value).unwrap(), + ); + + self + } + }; + } + + if is_type(field_type, "Vec") { + let inner_field_type = extract_inner_type(field_type, "Vec").unwrap(); + + let push_fn_name = syn::Ident::new(&format!("{}_push", field_name), field_name.span()); + + let documentation = format!("Add a value to the Vec `{field_name}`"); + let documentation2 = format!("Set the value of `{field_name}`"); + + return quote! { + #[doc = #documentation] + pub fn #push_fn_name(mut self, value: T) -> Self where T: Into<#inner_field_type> + serde::Serialize { + let mut push = self.changeset.entry("$push".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap(); + + if push.contains_key(#field_name_str) { + let current = push.get_mut(#field_name_str.to_string()).unwrap(); + + if current.as_document().map(|x| !x.contains_key("$each")).unwrap_or(true) { + let each = mongod::mongodb::bson::doc! { + "$each": [current, mongod::mongodb::bson::to_bson(&value).unwrap()] + }; + + push.insert(#field_name_str.to_string(), each); + } else { + current.as_document_mut().unwrap().get_mut("$each").unwrap().as_array_mut().unwrap().push(mongod::mongodb::bson::to_bson(&value).unwrap()); + } + } else { + push.insert(#field_name_str.to_string(), mongod::mongodb::bson::to_bson(&value).unwrap()); + } + + self.model.#field_name.push(value.into()); + + self + } + + #[doc = #documentation2] + pub fn #field_name(mut self, value: T)-> Self where T: Into<#field_type> + serde::Serialize { + self.model.#field_name = value.into(); + + self.changeset.entry("$set".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&self.model.#field_name).unwrap(), + ); + + self + } + }; + } + + if is_type(field_type, "Historic") { + let inner_field_type = extract_inner_type(field_type, "Historic").unwrap(); + + let documentation = format!( + "Update the value of `{field_name}`. This change will be recorded by the `Historic`" + ); + + // Code for Historic + return quote! { + #[doc = #documentation] + pub fn #field_name(mut self, value: T) -> Self where T: Into<#inner_field_type> + serde::Serialize { + self.model.#field_name.update(value.into()); + + self.changeset.entry("$set".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&self.model.#field_name).unwrap(), + ); + + + self + } + }; + } + + if is_type(field_type, "Option") { + let inner_field_type = extract_inner_type(field_type, "Option").unwrap(); + + if is_type(inner_field_type, "Historic") { + let inner_field_type = extract_inner_type(inner_field_type, "Historic").unwrap(); + + let documentation = format!("Update the value of `{field_name}`. This change will be recorded by the `Historic`. If `{field_name}` is `None` a new `Historic` will be initialized."); + + // Code for Option> + return quote! { + #[doc = #documentation] + pub fn #field_name(mut self, value: T) -> Self where T: Into<#inner_field_type> + serde::Serialize { + if let Some(mut opt) = self.model.#field_name.as_mut() { + opt.update(value.into()); + } else { + self.model.#field_name = Some(mongod::Historic::new(value.into())); + } + + self.changeset.entry("$set".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&self.model.#field_name).unwrap(), + ); + + self + } + }; + } + + let documentation = format!("Set the value of `{field_name}`. If `Some(_)` it will be updated or removed if it is `None`"); + + return quote! { + #[doc = #documentation] + pub fn #field_name(mut self, value: Option<#inner_field_type>) -> Self { + let is_some = value.is_some(); + + self.model.#field_name = value.into(); + + if is_some { + self.changeset.entry("$set".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&self.model.#field_name).unwrap(), + ); + } else { + self.changeset.entry("$unset".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson("").unwrap(), + ); + } + + self + } + }; + } + + let documentation = format!("Set the value of `{field_name}`"); + // Code for T + quote! { + #[doc = #documentation] + pub fn #field_name(mut self, value: T)-> Self where T: Into<#field_type> + serde::Serialize { + self.model.#field_name = value.into(); + + self.changeset.entry("$set".to_string()).or_insert(mongod::mongodb::bson::doc! {}.into()).as_document_mut().unwrap().insert( + #field_name_str.to_string(), + mongod::mongodb::bson::to_bson(&self.model.#field_name).unwrap(), + ); + + self + } + } +} diff --git a/mongod_derive/src/lib.rs b/mongod_derive/src/lib.rs index 7264223..8cae41d 100644 --- a/mongod_derive/src/lib.rs +++ b/mongod_derive/src/lib.rs @@ -3,40 +3,18 @@ use case::CaseExt; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields, Type, TypePath}; +use syn::{parse_macro_input, Data, DeriveInput, Fields}; -/// Get inner type. Example: Returns `T` for `Option`. -fn extract_inner_type<'a>(ty: &'a Type, parent: &'a str) -> Option<&'a Type> { - if let Type::Path(type_path) = ty { - if type_path.path.segments.len() == 1 { - let segment = &type_path.path.segments[0]; - if segment.ident == parent { - if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments { - if args.args.len() == 1 { - if let syn::GenericArgument::Type(ref inner_type) = args.args[0] { - return Some(inner_type); - } - } - } - } - } - } - None -} - -fn type_path(ty: &syn::Type) -> TypePath { - if let syn::Type::Path(type_path) = ty { - return type_path.clone(); - } - unreachable!(); -} - -fn is_type(ty: &syn::Type, t: &str) -> bool { - let type_path = type_path(ty); - let id = type_path.path.segments.first().unwrap().ident.to_string(); - id == t -} +mod types; +use types::*; +mod partial_fields; +use partial_fields::partial_code_field; +mod update_fields; +use update_fields::update_code_field; +mod change_fields; +use change_fields::builder_change_fields; +/// #[derive(Model)] #[proc_macro_derive(Model)] pub fn model_derive(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -44,84 +22,37 @@ pub fn model_derive(input: TokenStream) -> TokenStream { let name = input.ident; let name_str = name.to_string().to_snake(); let partial_name = syn::Ident::new(&format!("Partial{}", name), name.span()); + let changebuilder_name = syn::Ident::new(&format!("Change{}", name), name.span()); + + let update_doc = "Commit the builder changes to DB"; // Generate code for each field let field_code = if let Data::Struct(data_struct) = input.data { match data_struct.fields { Fields::Named(fields_named) => { - let field_process_code: Vec<_> = fields_named.named.iter().map(|field| { - let field_name = &field.ident.as_ref().unwrap(); - let field_type = &field.ty; - let field_name_str = field_name.to_string(); + // Update code + let field_process_code: Vec<_> = + fields_named.named.iter().map(update_code_field).collect(); - if field_name_str == "_id" { - return quote! {}; - } + // Partial struct fields + let partial_struct: Vec<_> = + fields_named.named.iter().map(partial_code_field).collect(); - if is_type(field_type, "Historic") { - let inner_field_type = extract_inner_type(field_type, "Historic").unwrap(); - if is_type(inner_field_type, "Vec") { - return quote! { - mongod::update_historic_vec!(self, obj, #field_name_str, #field_name, update); - }; - } - - return quote! { - mongod::update_historic_str!(self, obj, #field_name_str, #field_name, update); - } - } - - if is_type(field_type, "Option") { - let inner_field_type = extract_inner_type(field_type, "Option").unwrap(); - - if is_type(inner_field_type, "Historic") { - let sub_inner_field_type = extract_inner_type(inner_field_type, "Historic").unwrap(); - if is_type(sub_inner_field_type, "Reference") { - return quote! { - mongod::update_historic_ref_option!(self, obj, #field_name_str, #field_name, update); - }; - } - } - - return quote! { - mongod::update_value_option!(self, obj, #field_name_str, #field_name, update, #inner_field_type); - }; - } - - quote! { - mongod::update_value!(self, obj, #field_name_str, #field_name, update, #field_type); - } - }).collect(); - - let partial_struct: Vec<_> = fields_named + // Builder functions + let builder_change_fields: Vec<_> = fields_named .named .iter() - .map(|field| { - let field_name = &field.ident.as_ref().unwrap(); - let field_type = &field.ty; - let field_name_str = field_name.to_string(); - - if field_name_str == "_id" { - return quote! { - pub _id: String, - }; - } - - if is_type(field_type, "Option") { - return quote! { - pub #field_name: #field_type, - }; - } - - quote! { - pub #field_name: Option<#field_type>, - } - }) + .map(builder_change_fields) .collect(); quote! { impl mongod::model::Model for #name { type Partial = #partial_name; + type ChangeBuilder = #changebuilder_name; + + fn change_builder(self) -> Self::ChangeBuilder { + #changebuilder_name::new(self) + } async fn update_values( &mut self, @@ -137,6 +68,47 @@ pub fn model_derive(input: TokenStream) -> TokenStream { #( #partial_struct )* } + #[derive(Debug)] + pub struct #changebuilder_name { + model: #name, + changeset: mongod::mongodb::bson::Document + } + + impl #changebuilder_name { + pub fn new(model: #name) -> Self { + Self { + model, + changeset: mongod::mongodb::bson::doc! {} + } + } + + #[doc = #update_doc] + pub async fn update(self) -> Result<#name, mongod::model::UpdateError> { + let db = mongod::get_mongo!(); + let collection = mongod::col!(db, <#name as mongod::Referencable>::collection_name()); + + if let Err(msg) = mongod::Validate::validate(&self.model).await { + return Err(mongod::model::UpdateError::Validation(msg)); + } + + let changeset = self.changeset; + + collection + .update_one( + mongod::id_of!(mongod::Referencable::id(&self.model)), + changeset, + None, + ) + .await + .map_err(mongod::model::UpdateError::Database)?; + + Ok(self.model) + } + + #( #builder_change_fields )* + + } + impl mongod::model::reference::Referencable for #partial_name { fn collection_name() -> &'static str { #name_str @@ -157,6 +129,7 @@ pub fn model_derive(input: TokenStream) -> TokenStream { TokenStream::from(field_code) } +/// #[derive(Referencable)] #[proc_macro_derive(Referencable)] pub fn referencable_derive(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); diff --git a/mongod_derive/src/partial_fields.rs b/mongod_derive/src/partial_fields.rs new file mode 100644 index 0000000..5234cb0 --- /dev/null +++ b/mongod_derive/src/partial_fields.rs @@ -0,0 +1,30 @@ +use quote::quote; +use syn::Field; + +use crate::is_type; + +/// Generate struct fields for Partial Model +pub fn partial_code_field(field: &Field) -> proc_macro2::TokenStream { + let field_name = &field.ident.as_ref().unwrap(); + let field_type = &field.ty; + let field_name_str = field_name.to_string(); + + // Keep _id + if field_name_str == "_id" { + return quote! { + pub _id: String, + }; + } + + // Leave Option alone + if is_type(field_type, "Option") { + return quote! { + pub #field_name: #field_type, + }; + } + + // Turn every field into Option + quote! { + pub #field_name: Option<#field_type>, + } +} diff --git a/mongod_derive/src/types.rs b/mongod_derive/src/types.rs new file mode 100644 index 0000000..f4bc087 --- /dev/null +++ b/mongod_derive/src/types.rs @@ -0,0 +1,42 @@ +use syn::{Type, TypePath}; + +/// Get inner type. Example: Returns `T` for `Option`. +pub fn extract_inner_type<'a>(ty: &'a Type, parent: &'a str) -> Option<&'a Type> { + if let Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + let segment = &type_path.path.segments[0]; + if segment.ident == parent { + if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments { + if args.args.len() == 1 { + if let syn::GenericArgument::Type(ref inner_type) = args.args[0] { + return Some(inner_type); + } + } + } + } + } + } + None +} + +pub fn type_path(ty: &syn::Type) -> TypePath { + if let syn::Type::Path(type_path) = ty { + return type_path.clone(); + } + unreachable!(); +} + +pub fn is_one_of_type(ty: &syn::Type, t: &[&str]) -> bool { + for typ in t { + if is_type(ty, typ) { + return true; + } + } + false +} + +pub fn is_type(ty: &syn::Type, t: &str) -> bool { + let type_path = type_path(ty); + let id = type_path.path.segments.first().unwrap().ident.to_string(); + id == t +} diff --git a/mongod_derive/src/update_fields.rs b/mongod_derive/src/update_fields.rs new file mode 100644 index 0000000..f7bb9bd --- /dev/null +++ b/mongod_derive/src/update_fields.rs @@ -0,0 +1,55 @@ +use quote::quote; +use syn::Field; + +use crate::{extract_inner_type, is_type}; + +/// Generate code for the update fn of models +pub fn update_code_field(field: &Field) -> proc_macro2::TokenStream { + let field_name = &field.ident.as_ref().unwrap(); + let field_type = &field.ty; + let field_name_str = field_name.to_string(); + + // Never update _id + if field_name_str == "_id" { + return quote! {}; + } + + if is_type(field_type, "Historic") { + let inner_field_type = extract_inner_type(field_type, "Historic").unwrap(); + if is_type(inner_field_type, "Vec") { + // Custom code Historic> + return quote! { + mongod::update_historic_vec!(self, obj, #field_name_str, #field_name, update); + }; + } + + // Code for Historic + return quote! { + mongod::update_historic_str!(self, obj, #field_name_str, #field_name, update); + }; + } + + if is_type(field_type, "Option") { + let inner_field_type = extract_inner_type(field_type, "Option").unwrap(); + + if is_type(inner_field_type, "Historic") { + let sub_inner_field_type = extract_inner_type(inner_field_type, "Historic").unwrap(); + if is_type(sub_inner_field_type, "Reference") { + // Code for Option> + return quote! { + mongod::update_historic_ref_option!(self, obj, #field_name_str, #field_name, update); + }; + } + } + + // Code for Option + return quote! { + mongod::update_value_option!(self, obj, #field_name_str, #field_name, update, #inner_field_type); + }; + } + + // Code for T + quote! { + mongod::update_value!(self, obj, #field_name_str, #field_name, update, #field_type); + } +} diff --git a/src/model/historic.rs b/src/model/historic.rs index 85d8286..b9f5cc6 100644 --- a/src/model/historic.rs +++ b/src/model/historic.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, ops::Deref}; /// A struct to keep track of historical changes to a value. /// This struct represents a value that has a current state and a history of previous states. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Historic { /// The current value pub current: T, @@ -24,6 +24,13 @@ impl Historic { } } +impl Historic { + /// Create a new tracked value initialized with Default + pub fn new_default() -> Historic { + Self::new(T::default()) + } +} + impl Historic { /// Update the value. The change will be recorded. /// Will record a change even if the value is the same as the current one. diff --git a/src/model/mod.rs b/src/model/mod.rs index e0b850e..34896eb 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -18,8 +18,6 @@ pub mod reference; pub mod update; pub mod valid; -// todo : use mongodb projection to only get fields you actually use, maybe PartialModel shadow struct? - /// Error type when updating a model #[derive(Debug)] pub enum UpdateError { @@ -35,6 +33,7 @@ pub trait Model: Sized + Referencable + Validate + serde::Serialize + for<'a> serde::Deserialize<'a> { type Partial: DeserializeOwned; + type ChangeBuilder; /// Insert the `Model` into the database fn insert( @@ -293,6 +292,18 @@ pub trait Model: } } + fn change_builder(self) -> Self::ChangeBuilder; + + /// Update values of `Model` using a builder pattern + fn change(self) -> Self::ChangeBuilder { + #[cfg(feature = "cache")] + { + mongod::cache_write!().invalidate(Self::collection_name(), self.id()); + } + + self.change_builder() + } + /// Update values of `Model` into database fn update( &mut self,