comrade/src/job.rs
2025-03-09 06:34:41 +01:00

512 lines
16 KiB
Rust

use crossbeam::channel::{Receiver, Sender};
use rand::Rng;
use redis::{Commands, RedisResult};
use serde::{Deserialize, Serialize};
use std::{sync::mpsc, time::Duration};
/// Source from which a worker pulls `JobOrder`s.
pub enum TaskReceiverBackend<I, O> {
    /// Jobs arrive over an in-process crossbeam channel.
    Local(Receiver<JobOrder<I, O>>),
    /// Jobs are popped from a Valkey topic.
    Union(ValkeyJobDispatcher<I, O>),
}
impl<I, O> TaskReceiverBackend<I, O>
where
    I: Serialize + for<'a> Deserialize<'a>,
    O: Serialize + for<'a> Deserialize<'a>,
{
    /// Blocks until the next job order is available.
    ///
    /// # Errors
    /// Returns the channel error (rendered with `to_string`) for the local backend, or whatever
    /// error `ValkeyJobDispatcher::recv` reports for the Valkey backend.
    pub fn recv(&self) -> Result<JobOrder<I, O>, String> {
        match self {
            Self::Local(queue) => queue.recv().map_err(|err| err.to_string()),
            Self::Union(remote) => remote.recv(),
        }
    }
}
pub struct ValkeyTopicSubscriber<O> {
output: std::marker::PhantomData<O>,
topic: String,
client: redis::Client,
}
impl<O: for<'a> Deserialize<'a>> ValkeyTopicSubscriber<O> {
pub fn new(channel: &str) -> Self {
let client =
redis::Client::open(std::env::var("VALKEY_URL").expect("No $VALKEY_URL variable set"))
.expect("Invalid Redis URL");
Self {
output: std::marker::PhantomData,
topic: channel.to_string(),
client: client,
}
}
pub fn recv(&self) -> Option<O> {
let mut con = self
.client
.get_connection()
.expect("Failed to connect to Redis");
let result: RedisResult<Vec<String>> = con.blpop(&self.topic, 0.0);
match result {
Ok(msg) => {
let msg = msg.iter().nth(1).unwrap();
Some(serde_json::from_str(&msg).unwrap())
}
Err(_) => None,
}
}
pub fn recv_timeout(&self, timeout: std::time::Duration) -> Option<O> {
let mut con = self
.client
.get_connection()
.expect("Failed to connect to Redis");
let result: RedisResult<Vec<String>> = con.blpop(&self.topic, timeout.as_secs_f64());
match result {
Ok(msg) => {
let msg = msg.iter().nth(1).unwrap();
Some(serde_json::from_str(&msg).unwrap())
}
Err(_) => None,
}
}
}
#[derive(Clone)]
pub struct ValkeyJobDispatcher<I, O> {
input: std::marker::PhantomData<I>,
output: std::marker::PhantomData<O>,
topic: String,
client: redis::Client,
local: bool,
}
impl<I: Serialize + for<'a> Deserialize<'a>, O: Serialize + for<'a> Deserialize<'a>>
ValkeyJobDispatcher<I, O>
{
// Creates a new job dispatcher for the given topic.
pub fn new_topic(topic: &str, local: bool) -> Self {
let client =
redis::Client::open(std::env::var("VALKEY_URL").expect("No $VALKEY_URL variable set"))
.expect("Invalid Redis URL");
ValkeyJobDispatcher {
input: std::marker::PhantomData,
output: std::marker::PhantomData,
topic: topic.to_string(),
client,
local,
}
}
// todo : real pub sub
pub fn recv(&self) -> Result<JobOrder<I, O>, String> {
let mut con = self
.client
.get_connection()
.expect("Failed to connect to Redis");
let result: RedisResult<Vec<String>> = con.blpop(&self.topic, 0.0);
match result {
Ok(msg) => {
let msg = msg.iter().nth(1).unwrap();
if let serde_json::Value::Object(task) = serde_json::from_str(&msg).unwrap() {
let channel_id = task.get("task").unwrap().as_str().unwrap().to_string();
let params = task.get("params").unwrap();
Ok(JobOrder::new(
serde_json::from_value(params.clone()).unwrap(),
move |res| {
// send back to channel
let _: () = con
.rpush(
&channel_id,
serde_json::to_string(&serde_json::to_value(&res).unwrap())
.unwrap(),
)
.expect("Failed to send job");
},
))
} else {
Err(String::new())
}
}
Err(e) => {
log::error!("Valkey error: {e:?}");
Err(e.to_string())
}
}
}
}
impl<I: Serialize + for<'a> Deserialize<'a>, O: for<'a> Deserialize<'a> + Serialize>
JobDispatch<I, O> for ValkeyJobDispatcher<I, O>
{
// Sends a job to the Redis topic (publishes a message).
fn send(&self, param: I) -> O {
let mut con = self
.client
.get_connection()
.expect("Failed to connect to Redis");
// Pushing the job to the topic in Redis
let channel_id = uuid::Uuid::new_v4().to_string();
let _: () = con
.rpush(
&self.topic,
serde_json::to_string(&serde_json::json!({
"task": channel_id,
"params": &param
}))
.unwrap(),
)
.expect("Failed to send job");
ValkeyTopicSubscriber::new(&channel_id).recv().unwrap()
}
// Sends a job asynchronously (non-blocking).
fn send_async(&self, param: I) -> JobResult<O> {
let mut con = self
.client
.get_connection()
.expect("Failed to connect to Redis");
// Pushing the job to the topic in Redis
let channel_id = uuid::Uuid::new_v4().to_string();
let _: () = con
.rpush(
&self.topic,
serde_json::to_string(&serde_json::json!({
"task": channel_id,
"params": &param
}))
.unwrap(),
)
.expect("Failed to send job");
JobResult(ReceiverBackend::Valkey(ValkeyTopicSubscriber::new(
&channel_id,
)))
}
// Tries to send a job, returning None if unsuccessful.
fn try_send(&self, param: I) -> Option<O> {
let res = self.send_async(param);
res.wait_try()
}
}
/// A generic job dispatcher struct that allows sending jobs of type `I` and receiving results of type `O` using message passing.
pub struct JobDispatcher<I: Send + 'static, O: Send + 'static> {
    sender: Sender<JobOrder<I, O>>,
}
// Implemented by hand: `#[derive(Clone)]` would demand `I: Clone` and `O: Clone`, although only
// the channel sender needs to be cloned.
impl<I: Send + 'static, O: Send + 'static> Clone for JobDispatcher<I, O> {
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
        }
    }
}
/// Where a `JobResult` reads its answer from.
pub enum ReceiverBackend<O> {
    /// In-process reply channel filled by the `JobOrder` callback.
    Local(std::sync::mpsc::Receiver<O>),
    /// Valkey reply list read through a `ValkeyTopicSubscriber`.
    Valkey(ValkeyTopicSubscriber<O>),
}
impl<O> ReceiverBackend<O>
where
    O: for<'a> Deserialize<'a>,
{
    /// Blocks until the result is available.
    ///
    /// Returns `None` if the local sender was dropped or the Valkey read produced no value.
    pub fn recv(&self) -> Option<O> {
        match self {
            Self::Local(rx) => rx.recv().ok(),
            Self::Valkey(subscriber) => subscriber.recv(),
        }
    }

    /// Waits at most 300 ms for the result; `None` if nothing arrived in that window.
    pub fn recv_timeout(&self) -> Option<O> {
        let window = Duration::from_millis(300);
        match self {
            Self::Local(rx) => rx.recv_timeout(window).ok(),
            Self::Valkey(subscriber) => subscriber.recv_timeout(window),
        }
    }
}
/// Handle to the pending result of a job submitted with `send_async`.
pub struct JobResult<O>(ReceiverBackend<O>);
impl<O: for<'a> Deserialize<'a>> JobResult<O> {
    /// Waits for the result of the job and returns it.
    ///
    /// # Panics
    /// Panics if the result can never be received, e.g. the local `JobOrder` was dropped without
    /// `done` being called.
    pub fn wait(self) -> O {
        self.0
            .recv()
            .expect("job finished without sending a result")
    }

    /// Waits for the result of the job; `None` if it can never be received.
    pub fn wait_try(self) -> Option<O> {
        self.0.recv()
    }

    /// Waits a short, fixed time (see `ReceiverBackend::recv_timeout`) for the result.
    ///
    /// Returns `None` if the result is not ready yet; the handle can be polled again.
    pub fn wait_timeout(&self) -> Option<O> {
        self.0.recv_timeout()
    }
}
impl<I: Send + 'static, O: Send + 'static> JobDispatcher<I, O> {
    /// Capacity of the bounded job queue; submitting blocks while this many jobs are queued.
    // NOTE(review): 8092 may have been intended as 8192; the value is kept unchanged.
    const QUEUE_CAPACITY: usize = 8092;

    /// Creates a new instance of `JobDispatcher` and returns a tuple that contains it and a receiver end for `JobOrder`s.
    /// # Example:
    /// ```
    /// use jobdispatcher::*;
    /// // Create job dispatcher
    /// let (dispatcher, recv) = JobDispatcher::<i32, i32>::new();
    ///
    /// // Worker Thread
    /// std::thread::spawn(move || {
    ///     for job in recv {
    ///         let result = job.param + 1;
    ///         job.done(result);
    ///     }
    /// });
    ///
    /// // Usage
    /// let result = dispatcher.send(3);
    /// assert_eq!(result, 4);
    /// ```
    #[must_use]
    pub fn new() -> (Self, Receiver<JobOrder<I, O>>) {
        let (sender, receiver) = crossbeam::channel::bounded(Self::QUEUE_CAPACITY);
        (Self { sender }, receiver)
    }
}
impl<I: Send + 'static, O: Send + 'static> JobDispatch<I, O> for JobDispatcher<I, O> {
/// Sends a job of type `T` to the job dispatcher and waits for its result of type `V`.
/// Returns the result of the job once it has been processed.
/// # Panics
/// This function panics when the `JobOrder` struct gets out of scope without returning a finished result.
/// Additionally if the internal `Mutex` is poisoned, this function will panic as well.
fn send(&self, param: I) -> O {
let (tx, rx) = mpsc::channel();
let job_order = JobOrder::new(param, move |ret| {
tx.send(ret).unwrap();
});
self.sender.send(job_order).unwrap();
rx.recv().unwrap()
}
fn send_async(&self, param: I) -> JobResult<O> {
let (tx, rx) = mpsc::channel();
let job_order = JobOrder::new(param, move |ret| {
tx.send(ret).unwrap();
});
self.sender.send(job_order).unwrap();
JobResult(ReceiverBackend::Local(rx))
}
/// Sends a job of type `T` to the job dispatcher and waits for its result of type `V`.
/// Returns `Some(V)` when the job returns an result, `None` if somehow nothing was returned or the internal `Mutex` is poisoned.
fn try_send(&self, param: I) -> Option<O> {
let (tx, rx) = mpsc::channel();
let job_order = JobOrder::new(param, move |ret| {
tx.send(ret).unwrap();
});
self.sender.send(job_order).ok()?;
rx.recv().ok()
}
}
/// Common interface for submitting jobs with parameters of type `I` that produce results of
/// type `O`.
pub trait JobDispatch<I, O> {
    /// Submits a job and blocks until its result is available.
    fn send(&self, param: I) -> O;

    /// Submits a job and blocks for its result, returning `None` when that fails.
    fn try_send(&self, param: I) -> Option<O>;

    /// Submits a job and returns a handle that can be waited on later.
    fn send_async(&self, param: I) -> JobResult<O>;
}
/// A job order: the job parameter of type `I` plus a one-shot callback that delivers the job's
/// result of type `O` back to whoever submitted the job.
pub struct JobOrder<I, O> {
    /// The job parameter of type `I`.
    pub param: I,
    /// Invoked by `done` with the job's result; never invoked if the order is dropped first.
    callback: Box<dyn FnOnce(O) + Send>,
}
impl<I, O> JobOrder<I, O> {
    /// Creates a new `JobOrder` from the job parameter `param` of type `I` and a `callback`
    /// that receives the job's result of type `O`.
    #[must_use]
    fn new(param: I, callback: impl FnOnce(O) + Send + 'static) -> Self {
        Self {
            param,
            callback: Box::new(callback),
        }
    }

    /// Sends the result of the `JobOrder` back to its origin by invoking the stored callback.
    pub fn done(self, val: O) {
        (self.callback)(val);
    }
}
pub enum Dispatcher<I: Send + 'static, O: Send + 'static> {
Local(JobDispatcher<I, O>),
Union(ValkeyJobDispatcher<I, O>),
}
impl<
I: Serialize + for<'a> Deserialize<'a> + Send + 'static,
O: Serialize + for<'a> Deserialize<'a> + Send + 'static,
> Dispatcher<I, O>
{
pub fn is_local(&self) -> bool {
match self {
Dispatcher::Local(_) => true,
Dispatcher::Union(valkey_job_dispatcher) => valkey_job_dispatcher.local,
}
}
fn send(&self, param: I) -> O {
match self {
Dispatcher::Local(job_dispatcher) => job_dispatcher.send(param),
Dispatcher::Union(valkey_job_dispatcher) => valkey_job_dispatcher.send(param),
}
}
fn send_async(&self, param: I) -> JobResult<O> {
match self {
Dispatcher::Local(job_dispatcher) => job_dispatcher.send_async(param),
Dispatcher::Union(valkey_job_dispatcher) => valkey_job_dispatcher.send_async(param),
}
}
}
/// Spreads jobs across several `Dispatcher`s, picking one uniformly at random for each job.
pub struct JobMultiplexer<I: Send + 'static, O: Send + 'static> {
    dispatchers: Vec<Dispatcher<I, O>>,
}
/// Returns a uniformly random element of `list`, or `None` if `list` is empty.
fn get_random_item<T>(list: &[T]) -> Option<&T> {
    if list.is_empty() {
        return None;
    }
    let mut rng = rand::rng();
    let index = rng.random_range(0..list.len());
    list.get(index)
}
impl<
        I: Serialize + for<'a> Deserialize<'a> + Send + 'static,
        O: Serialize + for<'a> Deserialize<'a> + Send + 'static,
    > JobMultiplexer<I, O>
{
    /// Creates a multiplexer over `dispatchers`.
    pub fn from(dispatchers: Vec<Dispatcher<I, O>>) -> Self {
        Self { dispatchers }
    }

    /// Sends a job to a randomly chosen dispatcher and blocks for its result.
    ///
    /// # Panics
    /// Panics if the multiplexer holds no dispatchers.
    pub fn send(&self, param: I) -> O {
        self.pick_dispatcher().send(param)
    }

    /// Sends a job to a randomly chosen dispatcher without waiting for its result.
    ///
    /// # Panics
    /// Panics if the multiplexer holds no dispatchers.
    pub fn send_async(&self, param: I) -> JobResult<O> {
        self.pick_dispatcher().send_async(param)
    }

    /// Picks the dispatcher for the next job.
    fn pick_dispatcher(&self) -> &Dispatcher<I, O> {
        get_random_item(&self.dispatchers).expect("JobMultiplexer has no dispatchers")
    }
}
impl<
        I: Clone + Serialize + for<'a> Deserialize<'a> + Send + 'static,
        O: Serialize + for<'a> Deserialize<'a> + Send + 'static,
    > JobMultiplexer<I, O>
{
    /// Sends a clone of `param` to every dispatcher whose `is_local()` is `true`, one after
    /// another, blocking on and discarding each result. Other dispatchers are skipped.
    pub fn send_all(&self, param: I) {
        for d in self.dispatchers.iter().filter(|d| d.is_local()) {
            let _ = d.send(param.clone());
        }
    }
}
/// Iterator that yields the results of a `Vec<JobResult<O>>` in the order they become ready.
///
/// Each call to `next` polls the pending `JobResult<O>`s one after another with a short timeout
/// (`JobResult::wait_timeout`) and yields the first result it finds, removing that entry.
/// If only results that can never arrive remain (e.g. their `JobOrder` was dropped without
/// `done` being called), `next` keeps polling and never returns.
///
/// # Example
/// ```ignore
/// // Started tasks which are pending
/// let pending_tasks = vec![...];
///
/// for task in pending_tasks {
///     // Blocks on the first `JobResult<_>` even though later ones in the `Vec<_>` may already
///     // be finished and ready to process.
///     let result = task.wait();
///     // ...
/// }
///
/// // With the iterator
///
/// for value in PendingTaskIterator(pending_tasks) {
///     // Processing starts as soon as any result is finished.
///     // ...
/// }
/// ```
pub struct PendingTaskIterator<O>(pub Vec<JobResult<O>>);
impl<O: for<'a> Deserialize<'a>> Iterator for PendingTaskIterator<O> {
    type Item = O;
    fn next(&mut self) -> Option<Self::Item> {
        if self.0.is_empty() {
            return None;
        }
        // Keep sweeping over all pending results until one of them delivers.
        loop {
            let ready = self
                .0
                .iter()
                .enumerate()
                .find_map(|(i, task)| task.wait_timeout().map(|res| (i, res)));
            if let Some((i, res)) = ready {
                self.0.remove(i);
                return Some(res);
            }
        }
    }
}
/// Iterator that yields `(label, result)` pairs from a `Vec<(L, JobResult<O>)>` in the order the
/// results become ready.
///
/// Compared to [`PendingTaskIterator`], each pending `JobResult<O>` carries a label `L` that is
/// handed back together with its result, so results can be correlated with their origins.
///
/// Each call to `next` polls the pending results one after another with a short timeout
/// (`JobResult::wait_timeout`) and yields the first one that is ready, removing that entry.
/// If only results that can never arrive remain (e.g. their `JobOrder` was dropped without
/// `done` being called), `next` keeps polling and never returns.
///
/// # Example
/// ```ignore
/// // Started tasks which are pending, each tagged with a label
/// let pending_tasks = vec![...];
///
/// for (label, task) in pending_tasks {
///     // Blocks on the first `JobResult<_>` even though later ones in the `Vec<_>` may already
///     // be finished and ready to process.
///     let result = task.wait();
///     // ...
/// }
///
/// // With the iterator
///
/// for (label, value) in LabelPendingTaskIterator(pending_tasks) {
///     // Processing starts as soon as any result is finished.
///     // ...
/// }
/// ```
pub struct LabelPendingTaskIterator<L, O>(pub Vec<(L, JobResult<O>)>);
impl<L, O: for<'a> Deserialize<'a>> Iterator for LabelPendingTaskIterator<L, O> {
    type Item = (L, O);
    fn next(&mut self) -> Option<Self::Item> {
        if self.0.is_empty() {
            return None;
        }
        // Keep sweeping over all pending results until one of them delivers; the label is moved
        // out of the removed entry, so it never needs to be cloned.
        loop {
            let ready = self
                .0
                .iter()
                .enumerate()
                .find_map(|(i, (_, pending))| pending.wait_timeout().map(|res| (i, res)));
            if let Some((i, res)) = ready {
                let (label, _) = self.0.remove(i);
                return Some((label, res));
            }
        }
    }
}