comrade/comrade-macro/src/lib.rs
2025-03-07 21:20:25 +01:00

207 lines
7.3 KiB
Rust

use proc_macro::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::{FnArg, Ident, ItemFn, Pat, ReturnType, Type, parse_macro_input};
/// This macro turns this function into a worker.
///
/// This will upgrade the function and generate a few ones (`fn` is a placeholder for the functions name):
/// - `fn()` - This function will be exactly the same as the original but it will be computed by a worker.
/// - `fn_init(&ServiceManager) -> ServiceManager` - This function registers a worker thread on a `ServiceManager`.
/// - `fn_shutdown()` - This function issues a shutdown request.
/// - `fn_init_scoped(&ServiceManager) -> (ServiceManager, fn_Scoped)` - This function registers a worker thread on a `ServiceManager` and returns a scoped struct. You can call the underlying function with `.call()` on the struct and it will automatically shutdown any workers if it gets out of scope.
///
/// # Examples
/// ```ignore
/// use comrade::worker;
///
/// // Declare worker
/// #[worker]
/// pub fn multiply(a: i32, b: i32) -> i32 {
/// a * b
/// }
///
/// fn main() {
/// let s = ServiceManager::new().mode(comrade::service::ServiceMode::Decay);
///
/// // Init worker thread
/// let s = multiply_init(s);
/// let s = s.spawn();
///
/// // Usage
/// let x = multiply(2, 8);
/// println!("myfn {x}");
///
/// // Shutdown worker thread
/// multiply_shutdown();
///
/// s.join().unwrap();
/// }
/// ```
#[proc_macro_attribute]
pub fn worker(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input: ItemFn = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
// Extract parameter names and types separately
let params: Vec<(Ident, Type)> = input
.sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat_type) = arg {
let name = if let Pat::Ident(pat_ident) = *pat_type.pat.clone() {
pat_ident.ident.clone()
} else {
return None;
};
let ty = *pat_type.ty.clone();
Some((name, ty))
} else {
None
}
})
.collect();
// Extract parameter names and types into separate lists for quoting
let param_names: Vec<Ident> = params.iter().map(|(name, _)| name.clone()).collect();
let param_types: Vec<Type> = params.iter().map(|(_, ty)| ty.clone()).collect();
for t in &param_types {
println!("param {}", t.to_token_stream().to_string());
}
// Extract return type
let return_type = match &input.sig.output {
ReturnType::Type(_, ty) => quote!(#ty),
ReturnType::Default => quote!(()),
};
// Extract function body
let body = &input.block;
let wrapper_fn = format_ident!("{}_wrapper", fn_name);
let worker_fn = format_ident!("{}_worker", fn_name);
let init_fn = format_ident!("{}_init", fn_name);
let init_fn_scoped = format_ident!("{}_init_scoped", fn_name);
let fn_scope_struct = format_ident!("{}_Scoped", fn_name);
let fn_name_async = format_ident!("{}_async", fn_name);
let shutdown_fn = format_ident!("{}_shutdown", fn_name);
let param_unpacking = param_names.iter().enumerate().map(|(i, name)| {
if param_names.len() == 1 {
return quote! {
let #name = i;
};
}
let param_type = &param_types[i];
if let Type::Path(_) = param_type {
quote! {
let #name = i.#i;
}
} else {
quote! {
let #name = i;
}
}
});
let output = quote! {
pub fn #fn_name(#(#param_names: #param_types),*) -> #return_type {
let i: comrade::serde_json::Value = comrade::serde_json::to_value( (#(#param_names),*) ).unwrap();
serde_json::from_value(comrade::UNION.get(stringify!(#fn_name)).unwrap().send(i)).unwrap()
}
#[doc = "Will run the function non blocking returning a `JobResult<_>` for fetching a result later."]
pub fn #fn_name_async(#(#param_names: #param_types),*) -> comrade::job::JobResult<comrade::serde_json::Value> {
let i: comrade::serde_json::Value = comrade::serde_json::to_value( (#(#param_names),*) ).unwrap();
comrade::UNION.get(stringify!(#fn_name)).unwrap().send_async(i)
}
fn #wrapper_fn(task: JobOrder<comrade::serde_json::Value, comrade::serde_json::Value>) {
let i = task.param.clone();
// Deserialize the parameter into the function's expected types
let i: (#(#param_types),*) = comrade::serde_json::from_value(i).unwrap();
#(#param_unpacking)*
let res = #body;
task.done(comrade::serde_json::to_value(&res).unwrap());
}
pub fn #worker_fn(recv: Receiver<JobOrder<comrade::serde_json::Value, comrade::serde_json::Value>>) {
loop {
let task = recv.recv();
match task {
Ok(task) => {
if let comrade::serde_json::Value::Object(obj) = &task.param {
if obj.contains_key("task") {
log::info!("Shutdown requested for task worker {}", stringify!(#fn_name));
task.done(comrade::serde_json::json!({"ok": 1}));
break;
}
}
#wrapper_fn(task)
},
Err(e) => {
log::error!("Error receiving task: {e:?}");
}
}
}
}
#[doc = "Shutdown the worker"]
pub fn #shutdown_fn() {
comrade::UNION.get(stringify!(#fn_name)).unwrap().send(comrade::serde_json::json!({"task": "shutdown"}));
}
#[doc = "Initialize a worker thread on `ServiceManager`"]
pub fn #init_fn(sm: ServiceManager) -> ServiceManager {
let (dispatch, recv): (JobDispatcher<_, _>, Receiver<JobOrder<_, _>>) = JobDispatcher::new();
let sm = sm.register(stringify!(#worker_fn), move |_| {
#worker_fn(recv.clone());
});
comrade::UNION.insert(stringify!(#fn_name), dispatch);
sm
}
#[allow(non_camel_case_types)]
pub struct #fn_scope_struct {}
impl #fn_scope_struct {
pub fn call(&self, #(#param_names: #param_types),*) -> #return_type {
#fn_name(#(#param_names),*)
}
}
impl Drop for #fn_scope_struct {
fn drop(&mut self) {
log::info!("Scoped task worker got dropped.");
#shutdown_fn();
}
}
#[doc = "Initialize a worker thread on `ServiceManager` on a scoped lifetime"]
pub fn #init_fn_scoped(sm: ServiceManager) -> (ServiceManager, #fn_scope_struct) {
let (dispatch, recv): (JobDispatcher<_, _>, Receiver<JobOrder<_, _>>) = JobDispatcher::new();
let sm = sm.register(stringify!(#worker_fn), move |_| {
#worker_fn(recv.clone());
});
comrade::UNION.insert(stringify!(#fn_name), dispatch);
(sm, #fn_scope_struct {})
}
};
output.into()
}