mirror of
https://github.com/Microsoft/vscode
synced 2024-10-12 06:17:18 +00:00
cli: enable server message compression
This is the CLI side of enabling compression of servermsg's sent over the socket. It is feature-detected by the CLI sending protocolVersion=2. If present, the consumer can request compression by passing `compress:true` when setting up the server. In this mode, servermsg's are an inflate/deflate stream. Not a ton of code here, but was lots of fun tweaking to get it right :) Fixes https://github.com/microsoft/vscode/issues/163688
This commit is contained in:
parent
b982536f83
commit
ebfa4b0c3c
|
@ -10,7 +10,13 @@ use lazy_static::lazy_static;
|
|||
use crate::options::Quality;
|
||||
|
||||
pub const CONTROL_PORT: u16 = 31545;
|
||||
pub const PROTOCOL_VERSION: u32 = 1;
|
||||
|
||||
/// Protocol version sent to clients. This can be used to indiciate new or
|
||||
/// changed capabilities that clients may wish to leverage.
|
||||
/// 1 - Initial protocol version
|
||||
/// 2 - Addition of `serve.compressed` property to control whether servermsg's
|
||||
/// are compressed bidirectionally.
|
||||
pub const PROTOCOL_VERSION: u32 = 2;
|
||||
|
||||
pub const VSCODE_CLI_VERSION: Option<&'static str> = option_env!("VSCODE_CLI_VERSION");
|
||||
pub const VSCODE_CLI_AI_KEY: Option<&'static str> = option_env!("VSCODE_CLI_AI_KEY");
|
||||
|
|
|
@ -8,6 +8,7 @@ pub mod dev_tunnels;
|
|||
pub mod legal;
|
||||
pub mod paths;
|
||||
|
||||
mod socket_signal;
|
||||
mod control_server;
|
||||
mod name_generator;
|
||||
mod port_forwarder;
|
||||
|
|
|
@ -7,6 +7,8 @@ use crate::constants::{CONTROL_PORT, PROTOCOL_VERSION, VSCODE_CLI_VERSION};
|
|||
use crate::log;
|
||||
use crate::self_update::SelfUpdate;
|
||||
use crate::state::LauncherPaths;
|
||||
use crate::tunnels::protocol::HttpRequestParams;
|
||||
use crate::tunnels::socket_signal::CloseReason;
|
||||
use crate::update_service::{Platform, UpdateService};
|
||||
use crate::util::errors::{
|
||||
wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError,
|
||||
|
@ -18,7 +20,6 @@ use crate::util::io::SilentCopyProgress;
|
|||
use crate::util::sync::{new_barrier, Barrier};
|
||||
use opentelemetry::trace::SpanKind;
|
||||
use opentelemetry::KeyValue;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::env;
|
||||
|
@ -38,12 +39,12 @@ use super::paths::prune_stopped_servers;
|
|||
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
|
||||
use super::protocol::{
|
||||
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse,
|
||||
ForwardParams, ForwardResult, GetHostnameResponse, HttpRequestParams, RefServerMessageParams,
|
||||
ResponseError, ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod,
|
||||
SuccessResponse, ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult,
|
||||
VersionParams,
|
||||
ForwardParams, ForwardResult, GetHostnameResponse, ResponseError, ServeParams, ServerLog,
|
||||
ServerMessageParams, ServerRequestMethod, SuccessResponse, ToClientRequest, ToServerRequest,
|
||||
UnforwardParams, UpdateParams, UpdateResult, VersionParams,
|
||||
};
|
||||
use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge};
|
||||
use super::server_bridge::{get_socket_rw_stream, ServerBridge};
|
||||
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink, SocketSignal};
|
||||
|
||||
type ServerBridgeList = Option<Vec<(u16, ServerBridge)>>;
|
||||
type ServerBridgeListLock = Arc<Mutex<ServerBridgeList>>;
|
||||
|
@ -122,39 +123,6 @@ enum ServerSignal {
|
|||
Respawn,
|
||||
}
|
||||
|
||||
struct CloseReason(String);
|
||||
|
||||
enum SocketSignal {
|
||||
/// Signals bytes to send to the socket.
|
||||
Send(Vec<u8>),
|
||||
/// Closes the socket (e.g. as a result of an error)
|
||||
CloseWith(CloseReason),
|
||||
/// Disposes ServerBridge corresponding to an ID
|
||||
CloseServerBridge(u16),
|
||||
}
|
||||
|
||||
impl SocketSignal {
|
||||
fn from_message<T>(msg: &T) -> Self
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromServerMessage for SocketSignal {
|
||||
fn from_server_message(i: u16, body: &[u8]) -> Self {
|
||||
SocketSignal::from_message(&ToClientRequest {
|
||||
id: None,
|
||||
params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }),
|
||||
})
|
||||
}
|
||||
|
||||
fn from_closed_server_bridge(i: u16) -> Self {
|
||||
SocketSignal::CloseServerBridge(i)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerTermination {
|
||||
/// Whether the server should be respawned in a new binary (see ServerSignal.Respawn).
|
||||
pub respawn: bool,
|
||||
|
@ -719,7 +687,15 @@ async fn handle_serve(
|
|||
}
|
||||
};
|
||||
|
||||
attach_server_bridge(&log, server, socket_tx, server_bridges, params.socket_id).await?;
|
||||
attach_server_bridge(
|
||||
&log,
|
||||
server,
|
||||
socket_tx,
|
||||
server_bridges,
|
||||
params.socket_id,
|
||||
params.compress,
|
||||
)
|
||||
.await?;
|
||||
Ok(EmptyResult {})
|
||||
}
|
||||
|
||||
|
@ -729,8 +705,22 @@ async fn attach_server_bridge(
|
|||
socket_tx: mpsc::Sender<SocketSignal>,
|
||||
server_bridges: ServerBridgeListLock,
|
||||
socket_id: u16,
|
||||
compress: bool,
|
||||
) -> Result<u16, AnyError> {
|
||||
let attached_fut = ServerBridge::new(&code_server.socket, socket_id, &socket_tx).await;
|
||||
let (server_messages, decoder) = if compress {
|
||||
(
|
||||
ServerMessageSink::new_compressed(socket_tx),
|
||||
ClientMessageDecoder::new_compressed(),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
ServerMessageSink::new_plain(socket_tx),
|
||||
ClientMessageDecoder::new_plain(),
|
||||
)
|
||||
};
|
||||
|
||||
let attached_fut =
|
||||
ServerBridge::new(&code_server.socket, socket_id, server_messages, decoder).await;
|
||||
|
||||
match attached_fut {
|
||||
Ok(a) => {
|
||||
|
|
|
@ -91,6 +91,9 @@ pub struct ServeParams {
|
|||
pub extensions: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub use_local_download: bool,
|
||||
/// If true, the client and server should gzip servermsg's sent in either direction.
|
||||
#[serde(default)]
|
||||
pub compress: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
|
|
|
@ -7,18 +7,15 @@ use std::path::Path;
|
|||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{unix::OwnedWriteHalf, UnixStream},
|
||||
sync::mpsc::Sender,
|
||||
};
|
||||
|
||||
use crate::util::errors::{wrap, AnyError};
|
||||
|
||||
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
|
||||
|
||||
pub struct ServerBridge {
|
||||
write: OwnedWriteHalf,
|
||||
}
|
||||
|
||||
pub trait FromServerMessage {
|
||||
fn from_server_message(index: u16, message: &[u8]) -> Self;
|
||||
fn from_closed_server_bridge(i: u16) -> Self;
|
||||
decoder: ClientMessageDecoder,
|
||||
}
|
||||
|
||||
pub async fn get_socket_rw_stream(path: &Path) -> Result<UnixStream, AnyError> {
|
||||
|
@ -38,25 +35,26 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result<UnixStream, AnyError> {
|
|||
const BUFFER_SIZE: usize = 65536;
|
||||
|
||||
impl ServerBridge {
|
||||
pub async fn new<T>(path: &Path, index: u16, target: &Sender<T>) -> Result<Self, AnyError>
|
||||
where
|
||||
T: 'static + FromServerMessage + Send,
|
||||
{
|
||||
pub async fn new(
|
||||
path: &Path,
|
||||
index: u16,
|
||||
mut target: ServerMessageSink,
|
||||
decoder: ClientMessageDecoder,
|
||||
) -> Result<Self, AnyError> {
|
||||
let stream = get_socket_rw_stream(path).await?;
|
||||
let (mut read, write) = stream.into_split();
|
||||
|
||||
let tx = target.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut read_buf = vec![0; BUFFER_SIZE];
|
||||
loop {
|
||||
match read.read(&mut read_buf).await {
|
||||
Err(_) => return,
|
||||
Ok(0) => {
|
||||
let _ = tx.send(T::from_closed_server_bridge(index)).await;
|
||||
let _ = target.closed_server_bridge(index).await;
|
||||
return; // EOF
|
||||
}
|
||||
Ok(s) => {
|
||||
let send = tx.send(T::from_server_message(index, &read_buf[..s])).await;
|
||||
let send = target.server_message(index, &read_buf[..s]).await;
|
||||
if send.is_err() {
|
||||
return;
|
||||
}
|
||||
|
@ -65,11 +63,14 @@ impl ServerBridge {
|
|||
}
|
||||
});
|
||||
|
||||
Ok(ServerBridge { write })
|
||||
Ok(ServerBridge { write, decoder })
|
||||
}
|
||||
|
||||
pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
|
||||
self.write.write_all(&b).await?;
|
||||
let dec = self.decoder.decode(&b)?;
|
||||
if !dec.is_empty() {
|
||||
self.write.write_all(dec).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -14,13 +14,11 @@ use tokio::{
|
|||
|
||||
use crate::util::errors::{wrap, AnyError};
|
||||
|
||||
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
|
||||
|
||||
pub struct ServerBridge {
|
||||
write_tx: mpsc::Sender<Vec<u8>>,
|
||||
}
|
||||
|
||||
pub trait FromServerMessage {
|
||||
fn from_server_message(index: u16, message: &[u8]) -> Self;
|
||||
fn from_closed_server_bridge(i: u16) -> Self;
|
||||
decoder: ClientMessageDecoder,
|
||||
}
|
||||
|
||||
const BUFFER_SIZE: usize = 65536;
|
||||
|
@ -49,13 +47,14 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result<NamedPipeClient, AnyErr
|
|||
}
|
||||
|
||||
impl ServerBridge {
|
||||
pub async fn new<T>(path: &Path, index: u16, target: &mpsc::Sender<T>) -> Result<Self, AnyError>
|
||||
where
|
||||
T: 'static + FromServerMessage + Send,
|
||||
{
|
||||
pub async fn new(
|
||||
path: &Path,
|
||||
index: u16,
|
||||
mut target: ServerMessageSink,
|
||||
decoder: ClientMessageDecoder,
|
||||
) -> Result<Self, AnyError> {
|
||||
let client = get_socket_rw_stream(path).await?;
|
||||
let (write_tx, mut write_rx) = mpsc::channel(4);
|
||||
let read_tx = target.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut read_buf = vec![0; BUFFER_SIZE];
|
||||
let mut pending_recv: Option<Vec<u8>> = None;
|
||||
|
@ -89,9 +88,7 @@ impl ServerBridge {
|
|||
match client.try_read(&mut read_buf) {
|
||||
Ok(0) => return, // EOF
|
||||
Ok(s) => {
|
||||
let send = read_tx
|
||||
.send(T::from_server_message(index, &read_buf[..s]))
|
||||
.await;
|
||||
let send = target.server_message(index, &read_buf[..s]).await;
|
||||
if send.is_err() {
|
||||
return;
|
||||
}
|
||||
|
@ -118,11 +115,14 @@ impl ServerBridge {
|
|||
}
|
||||
});
|
||||
|
||||
Ok(ServerBridge { write_tx })
|
||||
Ok(ServerBridge { write_tx, decoder })
|
||||
}
|
||||
|
||||
pub async fn write(&self, b: Vec<u8>) -> std::io::Result<()> {
|
||||
self.write_tx.send(b).await.ok();
|
||||
let dec = self.decoder.decode(&b)?;
|
||||
if !dec.is_empty() {
|
||||
self.write_tx.send(dec.to_vec()).await.ok();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
244
cli/src/tunnels/socket_signal.rs
Normal file
244
cli/src/tunnels/socket_signal.rs
Normal file
|
@ -0,0 +1,244 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use serde::Serialize;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest};
|
||||
|
||||
pub struct CloseReason(pub String);
|
||||
|
||||
pub enum SocketSignal {
|
||||
/// Signals bytes to send to the socket.
|
||||
Send(Vec<u8>),
|
||||
/// Closes the socket (e.g. as a result of an error)
|
||||
CloseWith(CloseReason),
|
||||
/// Disposes ServerBridge corresponding to an ID
|
||||
CloseServerBridge(u16),
|
||||
}
|
||||
|
||||
impl SocketSignal {
|
||||
pub fn from_message<T>(msg: &T) -> Self
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct that handling sending or closing a connected server socket.
|
||||
pub struct ServerMessageSink {
|
||||
tx: mpsc::Sender<SocketSignal>,
|
||||
flate: Option<FlateStream<CompressFlateAlgorithm>>,
|
||||
}
|
||||
|
||||
impl ServerMessageSink {
|
||||
pub fn new_plain(tx: mpsc::Sender<SocketSignal>) -> Self {
|
||||
Self { tx, flate: None }
|
||||
}
|
||||
|
||||
pub fn new_compressed(tx: mpsc::Sender<SocketSignal>) -> Self {
|
||||
Self {
|
||||
tx,
|
||||
flate: Some(FlateStream::new(CompressFlateAlgorithm(
|
||||
flate2::Compress::new(flate2::Compression::new(2), false),
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn server_message(
|
||||
&mut self,
|
||||
i: u16,
|
||||
body: &[u8],
|
||||
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
|
||||
let msg = {
|
||||
let body = self.get_server_msg_content(body);
|
||||
SocketSignal::from_message(&ToClientRequest {
|
||||
id: None,
|
||||
params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }),
|
||||
})
|
||||
};
|
||||
|
||||
self.tx.send(msg).await
|
||||
}
|
||||
|
||||
pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] {
|
||||
if let Some(flate) = &mut self.flate {
|
||||
if let Ok(compressed) = flate.process(body) {
|
||||
return compressed;
|
||||
}
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
pub async fn closed_server_bridge(
|
||||
&mut self,
|
||||
i: u16,
|
||||
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
|
||||
self.tx.send(SocketSignal::CloseServerBridge(i)).await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClientMessageDecoder {
|
||||
dec: Option<FlateStream<DecompressFlateAlgorithm>>,
|
||||
}
|
||||
|
||||
impl ClientMessageDecoder {
|
||||
pub fn new_plain() -> Self {
|
||||
ClientMessageDecoder { dec: None }
|
||||
}
|
||||
|
||||
pub fn new_compressed() -> Self {
|
||||
ClientMessageDecoder {
|
||||
dec: Some(FlateStream::new(DecompressFlateAlgorithm(
|
||||
flate2::Decompress::new(false),
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> {
|
||||
match &mut self.dec {
|
||||
Some(d) => d.process(message),
|
||||
None => Ok(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait FlateAlgorithm {
|
||||
fn total_in(&self) -> u64;
|
||||
fn total_out(&self) -> u64;
|
||||
fn process(
|
||||
&mut self,
|
||||
contents: &[u8],
|
||||
output: &mut [u8],
|
||||
) -> Result<flate2::Status, std::io::Error>;
|
||||
}
|
||||
|
||||
struct DecompressFlateAlgorithm(flate2::Decompress);
|
||||
|
||||
impl FlateAlgorithm for DecompressFlateAlgorithm {
|
||||
fn total_in(&self) -> u64 {
|
||||
self.0.total_in()
|
||||
}
|
||||
|
||||
fn total_out(&self) -> u64 {
|
||||
self.0.total_out()
|
||||
}
|
||||
|
||||
fn process(
|
||||
&mut self,
|
||||
contents: &[u8],
|
||||
output: &mut [u8],
|
||||
) -> Result<flate2::Status, std::io::Error> {
|
||||
self.0
|
||||
.decompress(contents, output, flate2::FlushDecompress::None)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
|
||||
}
|
||||
}
|
||||
|
||||
struct CompressFlateAlgorithm(flate2::Compress);
|
||||
|
||||
impl FlateAlgorithm for CompressFlateAlgorithm {
|
||||
fn total_in(&self) -> u64 {
|
||||
self.0.total_in()
|
||||
}
|
||||
|
||||
fn total_out(&self) -> u64 {
|
||||
self.0.total_out()
|
||||
}
|
||||
|
||||
fn process(
|
||||
&mut self,
|
||||
contents: &[u8],
|
||||
output: &mut [u8],
|
||||
) -> Result<flate2::Status, std::io::Error> {
|
||||
self.0
|
||||
.compress(contents, output, flate2::FlushCompress::Sync)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
|
||||
}
|
||||
}
|
||||
|
||||
struct FlateStream<A>
|
||||
where
|
||||
A: FlateAlgorithm,
|
||||
{
|
||||
flate: A,
|
||||
output: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<A> FlateStream<A>
|
||||
where
|
||||
A: FlateAlgorithm,
|
||||
{
|
||||
pub fn new(alg: A) -> Self {
|
||||
Self {
|
||||
flate: alg,
|
||||
output: vec![0; 4096],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> {
|
||||
let mut out_offset = 0;
|
||||
let mut in_offset = 0;
|
||||
loop {
|
||||
let in_before = self.flate.total_in();
|
||||
let out_before = self.flate.total_out();
|
||||
|
||||
match self
|
||||
.flate
|
||||
.process(&contents[in_offset..], &mut self.output[out_offset..])
|
||||
{
|
||||
Ok(flate2::Status::Ok | flate2::Status::BufError) => {
|
||||
let processed_len = in_offset + (self.flate.total_in() - in_before) as usize;
|
||||
let output_len = out_offset + (self.flate.total_out() - out_before) as usize;
|
||||
if processed_len < contents.len() {
|
||||
// If we filled the output buffer but there's more data to compress,
|
||||
// extend the output buffer and keep compressing.
|
||||
out_offset = output_len;
|
||||
in_offset = processed_len;
|
||||
if output_len == self.output.len() {
|
||||
self.output.resize(self.output.len() * 2, 0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok(&self.output[..output_len]);
|
||||
}
|
||||
Ok(flate2::Status::StreamEnd) => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"unexpected stream end",
|
||||
))
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// Note this useful idiom: importing names from outer (for mod tests) scope.
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_round_trips_compression() {
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
let mut sink = ServerMessageSink::new_compressed(tx);
|
||||
let mut decompress = ClientMessageDecoder::new_compressed();
|
||||
|
||||
// 3000 and 30000 test resizing the buffer
|
||||
for msg_len in [3, 30, 300, 3000, 30000] {
|
||||
let vals = (0..msg_len).map(|v| v as u8).collect::<Vec<u8>>();
|
||||
let compressed = sink.get_server_msg_content(&vals);
|
||||
assert_ne!(compressed, vals);
|
||||
let decompressed = decompress.decode(compressed).unwrap();
|
||||
assert_eq!(decompressed.len(), vals.len());
|
||||
assert_eq!(decompressed, vals);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue