cli: allow handling control server requests in parallel

Previously the control server could only handle a single request at a
time. To enable local download mode, this needs to change as the client
will be sending data to the CLI as it downloads the vscode server zip.

This does that. There's a little mess since things that async handlers
need to use are cloned out of the previously unified "context" (we
could try and clone the whole context each time, but this is more work
than needed.) We still keep the fast things as "blocking" since that
avoids the need for clones and separate tasks.
This commit is contained in:
Connor Peet 2022-11-08 12:00:31 -08:00
parent 2cd1bf6bd6
commit 161418296b
No known key found for this signature in database
GPG key ID: CF8FD2EA0DBC61BD
4 changed files with 208 additions and 126 deletions

View file

@ -318,17 +318,11 @@ fn detect_installed_program(log: &log::Logger, quality: Quality) -> io::Result<V
}
}
State::LookingForLocation => {
if line.starts_with(LOCATION_PREFIX) {
if let Some(suffix) = line.strip_prefix(LOCATION_PREFIX) {
output.push(
[
&line[LOCATION_PREFIX.len()..].trim(),
"Contents/Resources",
"app",
"bin",
"code",
]
.iter()
.collect(),
[suffix.trim(), "Contents/Resources", "app", "bin", "code"]
.iter()
.collect(),
);
state = State::LookingForName;
}
@ -338,7 +332,7 @@ fn detect_installed_program(log: &log::Logger, quality: Quality) -> io::Result<V
// Sort shorter paths to the front, preferring "more global" installs, and
// incidentally preferring local installs over Parallels 'installs'.
output.sort_by(|a, b| a.as_os_str().len().cmp(&b.as_os_str().len()));
output.sort_by_key(|a| a.as_os_str().len());
Ok(output)
}

View file

@ -24,6 +24,7 @@ use std::fs;
use std::fs::File;
use std::io::{ErrorKind, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::remove_file;
use tokio::io::{AsyncBufReadExt, BufReader};
@ -209,17 +210,19 @@ struct UpdateServerVersion {
}
/// Code server listening on a port address.
#[derive(Clone)]
pub struct SocketCodeServer {
pub commit_id: String,
pub socket: PathBuf,
pub origin: CodeServerOrigin,
pub origin: Arc<CodeServerOrigin>,
}
/// Code server listening on a socket address.
#[derive(Clone)]
pub struct PortCodeServer {
pub commit_id: String,
pub port: u16,
pub origin: CodeServerOrigin,
pub origin: Arc<CodeServerOrigin>,
}
/// A server listening on any address/location.
@ -448,7 +451,7 @@ impl<'a> ServerBuilder<'a> {
)
.await?;
let origin = CodeServerOrigin::Existing(pid);
let origin = Arc::new(CodeServerOrigin::Existing(pid));
let contents = fs::read_to_string(&self.server_paths.logfile)
.expect("Something went wrong reading log file");
@ -544,7 +547,7 @@ impl<'a> ServerBuilder<'a> {
Ok(SocketCodeServer {
commit_id: self.server_params.release.commit.to_owned(),
socket,
origin,
origin: Arc::new(origin),
})
}

View file

@ -42,6 +42,7 @@ use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge
type ServerBridgeList = Option<Vec<(u16, ServerBridge)>>;
type ServerBridgeListLock = Arc<Mutex<ServerBridgeList>>;
type CodeServerCell = Arc<Mutex<Option<SocketCodeServer>>>;
struct HandlerContext {
/// Exit barrier for the socket.
@ -55,7 +56,7 @@ struct HandlerContext {
/// Configured launcher paths.
launcher_paths: LauncherPaths,
/// Connected VS Code Server
code_server: Option<SocketCodeServer>,
code_server: CodeServerCell,
/// Potentially many "websocket" connections to client
server_bridges: ServerBridgeListLock,
// the cli arguments used to start the code server
@ -296,7 +297,7 @@ async fn process_socket(
launcher_paths,
code_server_args,
rx_counter: rx_counter_ctx,
code_server: None,
code_server: Arc::new(Mutex::new(None)),
server_bridges: server_bridges_lock,
port_forwarding,
platform,
@ -379,9 +380,12 @@ async fn handle_socket_read(
let mut did_update = false;
let result = loop {
match read_next(&mut socket_reader, ctx, &mut decode_buf, &mut did_update).await {
Ok(false) => break Ok(()),
Ok(true) => { /* continue */ }
match read_next(&mut socket_reader, ctx, &mut decode_buf).await {
Ok(None) => continue,
Ok(Some(m)) => {
dispatch_next(m, ctx, &mut did_update).await;
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break Ok(()),
Err(e) => break Err(e),
}
};
@ -394,16 +398,16 @@ async fn handle_socket_read(
result
}
/// Reads and handles the next data packet, returns true if the read loop should continue.
/// Reads and handles the next data packet. Returns the next packet to dispatch,
/// or an error (including EOF).
async fn read_next(
socket_reader: &mut BufReader<impl AsyncRead + Unpin>,
ctx: &mut HandlerContext,
decode_buf: &mut Vec<u8>,
did_update: &mut bool,
) -> Result<bool, std::io::Error> {
) -> Result<Option<ToServerRequest>, std::io::Error> {
let msg_length = tokio::select! {
u = socket_reader.read_u32() => u? as usize,
_ = ctx.closer.wait() => return Ok(false),
_ = ctx.closer.wait() => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
};
decode_buf.resize(msg_length, 0);
ctx.rx_counter
@ -411,17 +415,21 @@ async fn read_next(
tokio::select! {
r = socket_reader.read_exact(decode_buf) => r?,
_ = ctx.closer.wait() => return Ok(false),
_ = ctx.closer.wait() => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
};
let req = match rmp_serde::from_slice::<ToServerRequest>(decode_buf) {
Ok(req) => req,
match rmp_serde::from_slice::<ToServerRequest>(decode_buf) {
Ok(req) => Ok(Some(req)),
Err(e) => {
warning!(ctx.log, "Error decoding message: {}", e);
return Ok(true); // not fatal
Ok(None) // not fatal
}
};
}
}
// Dispatches a server request. Returns `true` if the socket reading should
// continue,
async fn dispatch_next(req: ToServerRequest, ctx: &mut HandlerContext, did_update: &mut bool) {
let log = ctx.log.prefixed(
req.id
.map(|id| format!("[call.{}]", id))
@ -429,70 +437,137 @@ async fn read_next(
.unwrap_or("notify"),
);
macro_rules! success {
($r:expr) => {
req.id
.map(|id| rmp_serde::to_vec_named(&SuccessResponse { id, result: &$r }))
macro_rules! send {
($tx:expr, $res:expr) => {
if let Some(Ok(res)) = $res {
$tx.send(SocketSignal::Send(res)).await.is_err()
} else {
false
}
};
}
macro_rules! tj {
($name:expr, $e:expr) => {
macro_rules! success {
($tx:expr, $r:expr) => {
send!(
$tx,
req.id
.map(|id| rmp_serde::to_vec_named(&SuccessResponse { id, result: &$r }))
)
};
}
macro_rules! dispatch_raw {
($log:expr, $socket_tx:expr, $name:expr, $e:expr) => {
match (spanf!(
log,
log.span(&format!("call.{}", $name))
$log,
$log.span(&format!("call.{}", $name))
.with_kind(opentelemetry::trace::SpanKind::Server),
$e
)) {
Ok(r) => success!(r),
Ok(r) => success!($socket_tx, r),
Err(e) => {
warning!(log, "error handling call: {:?}", e);
req.id.map(|id| {
rmp_serde::to_vec_named(&ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{:?}", e),
},
warning!($log, "error handling call: {:?}", e);
send!(
$socket_tx,
req.id.map(|id| {
rmp_serde::to_vec_named(&ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{:?}", e),
},
})
})
})
)
}
}
};
}
let response = match req.params {
ServerRequestMethod::ping(_) => success!(EmptyResult {}),
ServerRequestMethod::serve(p) => tj!("serve", handle_serve(ctx, &log, p)),
ServerRequestMethod::prune => tj!("prune", handle_prune(ctx)),
ServerRequestMethod::gethostname(_) => tj!("gethostname", handle_get_hostname()),
ServerRequestMethod::update(p) => tj!("update", async {
let r = handle_update(ctx, &p).await;
if matches!(&r, Ok(u) if u.did_update) {
*did_update = true;
}
r
}),
ServerRequestMethod::servermsg(m) => {
if let Err(e) = handle_server_message(ctx, m).await {
warning!(log, "error handling call: {:?}", e);
}
None
}
ServerRequestMethod::callserverhttp(p) => {
tj!("callserverhttp", handle_call_server_http(ctx, p))
}
ServerRequestMethod::forward(p) => tj!("forward", handle_forward(ctx, p)),
ServerRequestMethod::unforward(p) => tj!("unforward", handle_unforward(ctx, p)),
};
if let Some(Ok(res)) = response {
if ctx.socket_tx.send(SocketSignal::Send(res)).await.is_err() {
return Ok(false);
}
// Runs the $e expression synchronously, returning its Result to the socket.
// This should only be used for fast-returning functions, otherwise prefer
// dispatch_async.
macro_rules! dispatch_blocking {
($name:expr, $e:expr) => {
dispatch_raw!(ctx.log, ctx.socket_tx, $name, $e);
};
}
Ok(true)
// Runs the $e expression asynchronously, returning its Result to the socket.
macro_rules! dispatch_async {
($name:expr, $e:expr) => {
let socket_tx = ctx.socket_tx.clone();
let span_logger = ctx.log.clone();
tokio::spawn(async move { dispatch_raw!(span_logger, socket_tx, $name, $e) })
};
}
match req.params {
ServerRequestMethod::ping(_) => {
success!(ctx.socket_tx, EmptyResult {});
}
ServerRequestMethod::serve(params) => {
let log = ctx.log.clone();
let server_bridges = ctx.server_bridges.clone();
let code_server_args = ctx.code_server_args.clone();
let code_server = ctx.code_server.clone();
let platform = ctx.platform;
let socket_tx = ctx.socket_tx.clone();
let paths = ctx.launcher_paths.clone();
dispatch_async!(
"serve",
handle_serve(
log,
server_bridges,
code_server_args,
platform,
code_server,
socket_tx,
paths,
params
)
);
}
ServerRequestMethod::prune => {
let paths = ctx.launcher_paths.clone();
dispatch_blocking!("prune", handle_prune(&paths));
}
ServerRequestMethod::gethostname(_) => {
dispatch_blocking!("gethostname", handle_get_hostname());
}
ServerRequestMethod::update(p) => {
dispatch_blocking!("update", async {
let r = handle_update(&ctx.log, &p).await;
if matches!(&r, Ok(u) if u.did_update) {
*did_update = true;
}
r
});
}
ServerRequestMethod::servermsg(m) => {
// It's important this this is not dispatch_async'd, since otherwise
// the order of servermsg's could be switched, which could lead to errors.
let bridges_lock = ctx.server_bridges.clone();
if let Err(e) = handle_server_message(bridges_lock, m).await {
warning!(log, "error handling call: {:?}", e);
}
}
ServerRequestMethod::callserverhttp(p) => {
let code_server = ctx.code_server.lock().await.clone();
dispatch_async!("callserverhttp", handle_call_server_http(code_server, p));
}
ServerRequestMethod::forward(p) => {
let log = ctx.log.clone();
let port_forwarding = ctx.port_forwarding.clone();
dispatch_async!("forward", handle_forward(log, port_forwarding, p));
}
ServerRequestMethod::unforward(p) => {
let log = ctx.log.clone();
let port_forwarding = ctx.port_forwarding.clone();
dispatch_async!("unforward", handle_unforward(log, port_forwarding, p));
}
};
}
#[derive(Clone)]
@ -516,13 +591,17 @@ impl log::LogSink for ServerOutputSink {
fn write_result(&self, _message: &str) {}
}
#[allow(clippy::too_many_arguments)]
async fn handle_serve(
ctx: &mut HandlerContext,
log: &log::Logger,
log: log::Logger,
server_bridges: ServerBridgeListLock,
mut code_server_args: CodeServerArgs,
platform: Platform,
code_server: CodeServerCell,
socket_tx: mpsc::Sender<SocketSignal>,
launcher_paths: LauncherPaths,
params: ServeParams,
) -> Result<EmptyResult, AnyError> {
let mut code_server_args = ctx.code_server_args.clone();
// fill params.extensions into code_server_args.install_extensions
code_server_args
.install_extensions
@ -533,49 +612,55 @@ async fn handle_serve(
quality: params.quality,
code_server_args,
headless: true,
platform: ctx.platform,
platform,
}
.resolve(log)
.resolve(&log)
.await?;
if ctx.code_server.is_none() {
let install_log = log.tee(ServerOutputSink {
tx: ctx.socket_tx.clone(),
});
let sb = ServerBuilder::new(&install_log, &resolved, &ctx.launcher_paths);
let mut server_ref = code_server.lock().await;
let server = match &*server_ref {
Some(o) => o.clone(),
None => {
let install_log = log.tee(ServerOutputSink {
tx: socket_tx.clone(),
});
let sb = ServerBuilder::new(&install_log, &resolved, &launcher_paths);
let server = match sb.get_running().await? {
Some(AnyCodeServer::Socket(s)) => s,
Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())),
None => {
sb.setup().await?;
sb.listen_on_default_socket().await?
}
};
let server = match sb.get_running().await? {
Some(AnyCodeServer::Socket(s)) => s,
Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())),
None => {
sb.setup().await?;
sb.listen_on_default_socket().await?
}
};
ctx.code_server = Some(server);
}
server_ref.replace(server.clone());
server
}
};
attach_server_bridge(ctx, params.socket_id).await?;
attach_server_bridge(&log, server, socket_tx, server_bridges, params.socket_id).await?;
Ok(EmptyResult {})
}
async fn attach_server_bridge(ctx: &mut HandlerContext, socket_id: u16) -> Result<u16, AnyError> {
let attached_fut = ServerBridge::new(
&ctx.code_server.as_ref().unwrap().socket,
socket_id,
&ctx.socket_tx,
)
.await;
async fn attach_server_bridge(
log: &log::Logger,
code_server: SocketCodeServer,
socket_tx: mpsc::Sender<SocketSignal>,
server_bridges: ServerBridgeListLock,
socket_id: u16,
) -> Result<u16, AnyError> {
let attached_fut = ServerBridge::new(&code_server.socket, socket_id, &socket_tx).await;
match attached_fut {
Ok(a) => {
let mut lock = ctx.server_bridges.lock().await;
let mut lock = server_bridges.lock().await;
match &mut *lock {
Some(server_bridges) => (*server_bridges).push((socket_id, a)),
None => *lock = Some(vec![(socket_id, a)]),
}
trace!(ctx.log, "Attached to server");
trace!(log, "Attached to server");
Ok(socket_id)
}
Err(e) => Err(e),
@ -583,10 +668,10 @@ async fn attach_server_bridge(ctx: &mut HandlerContext, socket_id: u16) -> Resul
}
async fn handle_server_message(
ctx: &mut HandlerContext,
bridges_lock: ServerBridgeListLock,
params: ServerMessageParams,
) -> Result<EmptyResult, AnyError> {
let mut lock = ctx.server_bridges.lock().await;
let mut lock = bridges_lock.lock().await;
match &mut *lock {
Some(server_bridges) => {
@ -606,19 +691,16 @@ async fn handle_server_message(
Ok(EmptyResult {})
}
async fn handle_prune(ctx: &HandlerContext) -> Result<Vec<String>, AnyError> {
prune_stopped_servers(&ctx.launcher_paths).map(|v| {
async fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {
prune_stopped_servers(paths).map(|v| {
v.iter()
.map(|p| p.server_dir.display().to_string())
.collect()
})
}
async fn handle_update(
ctx: &HandlerContext,
params: &UpdateParams,
) -> Result<UpdateResult, AnyError> {
let update_service = UpdateService::new(ctx.log.clone(), reqwest::Client::new());
async fn handle_update(log: &log::Logger, params: &UpdateParams) -> Result<UpdateResult, AnyError> {
let update_service = UpdateService::new(log.clone(), reqwest::Client::new());
let updater = SelfUpdate::new(&update_service)?;
let latest_release = updater.get_current_release().await?;
let up_to_date = updater.is_up_to_date_with(&latest_release);
@ -630,7 +712,7 @@ async fn handle_update(
});
}
info!(ctx.log, "Updating CLI to {}", latest_release);
info!(log, "Updating CLI to {}", latest_release);
updater
.do_update(&latest_release, SilentCopyProgress())
@ -649,25 +731,27 @@ async fn handle_get_hostname() -> Result<GetHostnameResponse, Infallible> {
}
async fn handle_forward(
ctx: &HandlerContext,
log: log::Logger,
port_forwarding: PortForwarding,
params: ForwardParams,
) -> Result<ForwardResult, AnyError> {
info!(ctx.log, "Forwarding port {}", params.port);
let uri = ctx.port_forwarding.forward(params.port).await?;
info!(log, "Forwarding port {}", params.port);
let uri = port_forwarding.forward(params.port).await?;
Ok(ForwardResult { uri })
}
async fn handle_unforward(
ctx: &HandlerContext,
log: log::Logger,
port_forwarding: PortForwarding,
params: UnforwardParams,
) -> Result<EmptyResult, AnyError> {
info!(ctx.log, "Unforwarding port {}", params.port);
ctx.port_forwarding.unforward(params.port).await?;
info!(log, "Unforwarding port {}", params.port);
port_forwarding.unforward(params.port).await?;
Ok(EmptyResult {})
}
async fn handle_call_server_http(
ctx: &HandlerContext,
code_server: Option<SocketCodeServer>,
params: CallServerHttpParams,
) -> Result<CallServerHttpResult, AnyError> {
use hyper::{body, client::conn::Builder, Body, Request};
@ -675,7 +759,7 @@ async fn handle_call_server_http(
// We use Hyper directly here since reqwest doesn't support sockets/pipes.
// See https://github.com/seanmonstar/reqwest/issues/39
let socket = match &ctx.code_server {
let socket = match &code_server {
Some(cs) => &cs.socket,
None => return Err(AnyError::from(NoAttachedServerError())),
};

View file

@ -95,6 +95,7 @@ impl PortForwardingProcessor {
}
}
#[derive(Clone)]
pub struct PortForwarding {
tx: mpsc::Sender<PortForwardingRec>,
}