cli: ensure ordering of rpc server messages (#183558)

* cli: ensure ordering of rpc server messages

Sending lots of messages to a stream made each handler block on the async
tokio mutex. The mutex is "fair" to its waiters, but the handler tasks don't
reach the lock in the order their messages arrived, so write ordering wasn't
preserved (an illustrative sketch follows below). Instead, use the write_loop
approach I introduced in the server_multiplexer for the same reason some time
ago.

* fix clippy
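
To make the failure mode concrete, here is an editorial sketch, not part of this commit: names such as `handle_segment` are hypothetical, and it only mirrors the pre-change shape in which every incoming data message got its own async handler that awaited a shared tokio mutex before writing. The order in which writes reach the stream then depends on task scheduling rather than on message arrival order.

use std::sync::Arc;
use tokio::{io::AsyncWriteExt, sync::Mutex};

// Hypothetical handler: one task per incoming segment, all contending on an
// async mutex that guards the stream's write half.
async fn handle_segment(
    stream: Arc<Mutex<tokio::io::WriteHalf<tokio::io::DuplexStream>>>,
    segment: Vec<u8>,
) {
    // The write happens whenever this particular task wakes up and wins the
    // lock, not when the segment was received, so segments can go out of order.
    let mut w = stream.lock().await;
    let _ = w.write_all(&segment).await;
}

#[tokio::main]
async fn main() {
    let (_client, server) = tokio::io::duplex(64);
    let (_read, write) = tokio::io::split(server);
    let shared = Arc::new(Mutex::new(write));

    // "one" arrives before "two", but each is handed to its own task; which
    // task acquires the mutex first is up to the scheduler.
    let a = tokio::spawn(handle_segment(shared.clone(), b"one".to_vec()));
    let b = tokio::spawn(handle_segment(shared.clone(), b"two".to_vec()));
    let _ = tokio::join!(a, b);
}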
Connor Peet 2023-05-26 09:48:06 -07:00 committed by GitHub
parent 6a7e91ebff
commit 1942c0eccc
3 changed files with 95 additions and 23 deletions

View file

@@ -273,30 +273,21 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
 	/// Builds into a usable, sync rpc dispatcher.
 	pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
-		let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
-			Arc::new(tokio::sync::Mutex::new(HashMap::new()));
+		let streams = Streams::default();
 
 		let s1 = streams.clone();
 		self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
 			let s1 = s1.clone();
 			async move {
-				if let Some(mut s) = s1.lock().await.remove(&m.stream) {
-					let _ = s.shutdown().await;
-				}
+				s1.remove(m.stream).await;
 				Ok(())
 			}
 		});
 
 		let s2 = streams.clone();
-		self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
-			let s2 = s2.clone();
-			async move {
-				let mut lock = s2.lock().await;
-				if let Some(stream) = lock.get_mut(&m.stream) {
-					let _ = stream.write_all(&m.segment).await;
-				}
-				Ok(())
-			}
-		});
+		self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
+			s2.write(m.stream, m.segment);
+			Ok(())
 		});
 
 		RpcDispatcher {
@@ -400,7 +391,7 @@ pub struct RpcDispatcher<S, C> {
 	serializer: Arc<S>,
 	methods: Arc<HashMap<&'static str, Method>>,
 	calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
-	streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
+	streams: Streams,
 }
 
 static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
@@ -483,10 +474,9 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
 			return;
 		}
 
-		let mut streams_map = self.streams.lock().await;
 		for (stream_id, duplex) in dto.streams {
 			let (mut read, write) = tokio::io::split(duplex);
-			streams_map.insert(stream_id, write);
+			self.streams.insert(stream_id, write);
 
 			let write_tx = write_tx.clone();
 			let serial = self.serializer.clone();
@@ -538,6 +528,90 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
 	}
 }
 
+struct StreamRec {
+	write: Option<WriteHalf<DuplexStream>>,
+	q: Vec<Vec<u8>>,
+}
+
+#[derive(Clone, Default)]
+struct Streams {
+	map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
+}
+
+impl Streams {
+	pub async fn remove(&self, id: u32) {
+		let stream = self.map.lock().unwrap().remove(&id);
+		if let Some(s) = stream {
+			// if there's no 'write' right now, it'll shut down in the write_loop
+			if let Some(mut w) = s.write {
+				let _ = w.shutdown().await;
+			}
+		}
+	}
+
+	pub fn write(&self, id: u32, buf: Vec<u8>) {
+		let mut map = self.map.lock().unwrap();
+		if let Some(s) = map.get_mut(&id) {
+			s.q.push(buf);
+			if let Some(w) = s.write.take() {
+				tokio::spawn(write_loop(id, w, self.map.clone()));
+			}
+		}
+	}
+
+	pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
+		self.map.lock().unwrap().insert(
+			id,
+			StreamRec {
+				write: Some(stream),
+				q: Vec::new(),
+			},
+		);
+	}
+}
+
+/// Write loop started by `Streams.write`. It takes the WriteHalf, and
+/// runs until there's no more items in the 'write queue'. At that point, if the
+/// record still exists in the `streams` (i.e. we haven't shut down), it'll
+/// return the WriteHalf so that the next `write` call starts
+/// the loop again. Otherwise, it'll shut down the WriteHalf.
+///
+/// This is the equivalent of the same write_loop in the server_multiplexer.
+/// I couldn't figure out a nice way to abstract it without introducing
+/// performance overhead...
+async fn write_loop(
+	id: u32,
+	mut w: WriteHalf<DuplexStream>,
+	streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
+) {
+	let mut items_vec = vec![];
+	loop {
+		{
+			let mut lock = streams.lock().unwrap();
+			let stream_rec = match lock.get_mut(&id) {
+				Some(b) => b,
+				None => break,
+			};
+
+			if stream_rec.q.is_empty() {
+				stream_rec.write = Some(w);
+				return;
+			}
+
+			std::mem::swap(&mut stream_rec.q, &mut items_vec);
+		}
+
+		for item in items_vec.drain(..) {
+			if w.write_all(&item).await.is_err() {
+				break;
+			}
+		}
+	}
+
+	let _ = w.shutdown().await; // got here from `break` above, meaning our record got cleared. Close the bridge if so
+}
+
 const METHOD_STREAMS_STARTED: &str = "streams_started";
 const METHOD_STREAM_DATA: &str = "stream_data";
 const METHOD_STREAM_ENDED: &str = "stream_ended";
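
The write_loop doc comment above describes the queue/handoff contract: writers enqueue under a synchronous lock, and whichever writer finds the WriteHalf present spawns a single drain task that owns it until the queue empties. Below is an editorial sketch, not part of the commit: the same pattern with hypothetical names (`Rec`, `queue_write`, `drain`) in a standalone program (assuming the tokio crate with the `full` feature), so the ordering guarantee can be observed end to end.

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf};

// One record per stream: the write half (absent while a drain task owns it)
// and the queue of pending segments.
struct Rec {
    write: Option<WriteHalf<DuplexStream>>,
    q: Vec<Vec<u8>>,
}

type Shared = Arc<Mutex<HashMap<u32, Rec>>>;

// Called synchronously for each incoming segment: the segment's position in
// the queue is fixed at push time, under the std mutex.
fn queue_write(map: &Shared, id: u32, buf: Vec<u8>) {
    let mut lock = map.lock().unwrap();
    if let Some(rec) = lock.get_mut(&id) {
        rec.q.push(buf);
        if let Some(w) = rec.write.take() {
            // No drain task is running for this stream: start one.
            tokio::spawn(drain(id, w, map.clone()));
        }
    }
}

// Single owner of the WriteHalf; drains the queue, then hands the half back.
async fn drain(id: u32, mut w: WriteHalf<DuplexStream>, map: Shared) {
    let mut items = Vec::new();
    loop {
        {
            let mut lock = map.lock().unwrap();
            let rec = match lock.get_mut(&id) {
                Some(r) => r,
                None => break, // record removed: shut the stream down below
            };
            if rec.q.is_empty() {
                rec.write = Some(w); // hand back for the next queue_write
                return;
            }
            std::mem::swap(&mut rec.q, &mut items);
        }
        for item in items.drain(..) {
            let _ = w.write_all(&item).await;
        }
    }
    let _ = w.shutdown().await;
}

#[tokio::main]
async fn main() {
    let map: Shared = Arc::new(Mutex::new(HashMap::new()));
    let (mut client, server) = tokio::io::duplex(64);
    let (_server_read, server_write) = tokio::io::split(server);
    map.lock().unwrap().insert(
        1,
        Rec {
            write: Some(server_write),
            q: Vec::new(),
        },
    );

    // Both segments are enqueued synchronously and in order, so the drain
    // task must emit them in that order regardless of task scheduling.
    queue_write(&map, 1, b"first ".to_vec());
    queue_write(&map, 1, b"second".to_vec());

    let mut out = vec![0u8; 12];
    client.read_exact(&mut out).await.unwrap();
    assert_eq!(&out[..], b"first second");
}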

View file

@@ -105,7 +105,7 @@ impl ServerMultiplexer {
 		}
 	}
 
-/// Write loop started by `handle_server_message`. It take sthe ServerBridge, and
+/// Write loop started by `handle_server_message`. It takes the ServerBridge, and
 /// runs until there's no more items in the 'write queue'. At that point, if the
 /// record still exists in the bridges_lock (i.e. we haven't shut down), it'll
 /// return the ServerBridge so that the next handle_server_message call starts

View file

@@ -4,11 +4,9 @@
  *--------------------------------------------------------------------------------------------*/
 
 use async_trait::async_trait;
 use std::{marker::PhantomData, sync::Arc};
-use tokio::{
-	sync::{
-		broadcast, mpsc,
-		watch::{self, error::RecvError},
-	},
+use tokio::sync::{
+	broadcast, mpsc,
+	watch::{self, error::RecvError},
 };
 
 #[derive(Clone)]