cli: add streams to rpc, generic 'spawn' command (#179732)

* cli: apply improvements from integrated wsl branch

* cli: add streams to rpc, generic 'spawn' command

For the "exec server" concept, fyi @aeschli.

* update clippy and apply fixes

* fix unused imports :(
Connor Peet 2023-04-12 08:51:29 -07:00 committed by GitHub
parent bb7570f4f8
commit 2d8ff25c85
23 changed files with 572 additions and 184 deletions
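
For orientation before the per-file diffs: the new generic "spawn" command takes a command, arguments, and environment, runs the child on the remote side, and carries its stdio over an RPC duplex stream; only the exit information comes back in the response. A minimal sketch of the shapes involved, using the SpawnParams and SpawnResult types added in protocol.rs below (the function name and all values are illustrative):

use std::collections::HashMap;

// Sketch only: the data a "spawn" call carries. Stdin/stdout/stderr travel over
// the duplex stream opened for the request; the response just reports the exit.
fn example_spawn_shapes() -> (SpawnParams, SpawnResult) {
    let params = SpawnParams {
        command: "bash".to_string(),
        args: vec!["-lc".to_string(), "echo hello".to_string()],
        env: HashMap::from([("MY_FLAG".to_string(), "1".to_string())]),
    };
    let result = SpawnResult {
        message: "exit status: 0".to_string(),
        exit_code: 0,
    };
    (params, result)
}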

cli/Cargo.lock (generated)

@ -146,9 +146,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "1.2.1"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
[[package]]
name = "cache-padded"
@ -230,6 +230,7 @@ dependencies = [
"async-trait",
"atty",
"base64",
"bytes",
"cfg-if",
"chrono",
"clap",


@ -17,7 +17,7 @@ clap = { version = "3.0", features = ["derive", "env"] }
open = { version = "2.1.0" }
reqwest = { version = "0.11.9", default-features = false, features = ["json", "stream", "native-tls"] }
tokio = { version = "1.24.2", features = ["full"] }
tokio-util = { version = "0.7", features = ["compat"] }
tokio-util = { version = "0.7", features = ["compat", "codec"] }
flate2 = { version = "1.0.22" }
zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] }
regex = { version = "1.5.5" }
@ -54,6 +54,7 @@ thiserror = "1.0"
cfg-if = "1.0.0"
pin-project = "1.0"
console = "0.15"
bytes = "1.4"
[build-dependencies]
serde = { version = "1.0" }


@ -190,7 +190,7 @@ pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Resul
let auth = Auth::new(&ctx.paths, ctx.log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
dt.rename_tunnel(&rename_args.name).await?;
ctx.log.result(&format!(
ctx.log.result(format!(
"Successfully renamed this gateway to {}",
&rename_args.name
));
@ -287,7 +287,7 @@ pub async fn prune(ctx: CommandContext) -> Result<i32, AnyError> {
.filter(|s| s.get_running_pid().is_none())
.try_for_each(|s| {
ctx.log
.result(&format!("Deleted {}", s.server_dir.display()));
.result(format!("Deleted {}", s.server_dir.display()));
s.delete()
})
.map_err(AnyError::from)?;


@ -3,6 +3,8 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::sync::Arc;
use indicatif::ProgressBar;
use crate::{
@ -17,7 +19,7 @@ use super::{args::StandaloneUpdateArgs, CommandContext};
pub async fn update(ctx: CommandContext, args: StandaloneUpdateArgs) -> Result<i32, AnyError> {
let update_service = UpdateService::new(
ctx.log.clone(),
ReqwestSimpleHttp::with_client(ctx.http.clone()),
Arc::new(ReqwestSimpleHttp::with_client(ctx.http.clone())),
);
let update_service = SelfUpdate::new(&update_service)?;


@ -58,5 +58,5 @@ pub async fn show(ctx: CommandContext) -> Result<i32, AnyError> {
}
fn print_now_using(log: &log::Logger, version: &RequestedVersion, path: &Path) {
log.result(&format!("Now using {} from {}", version, path.display()));
log.result(format!("Now using {} from {}", version, path.display()));
}


@ -50,7 +50,7 @@ pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
let mut read = BufReader::new(read);
let mut read_buf = String::new();
@ -84,7 +84,18 @@ pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
}
});
},
MaybeSync::Stream((dto, fut)) => {
if let Some(dto) = dto {
dispatcher.register_stream(write_tx.clone(), dto).await;
}
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
let _ = write_tx.send(v).await;
}
});
}


@ -27,21 +27,19 @@ pub fn next_counter() -> u32 {
// Log level
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)]
#[derive(Default)]
pub enum Level {
Trace = 0,
Debug,
Info,
#[default]
Info,
Warn,
Error,
Critical,
Off,
}
impl Default for Level {
fn default() -> Self {
Level::Info
}
}
impl fmt::Display for Level {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {


@ -8,6 +8,7 @@ use tokio::{
pin,
sync::mpsc,
};
use tokio_util::codec::Decoder;
use crate::{
rpc::{self, MaybeSync, Serialization},
@ -38,7 +39,6 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
rpc::RpcBuilder::new(MsgPackSerializer {})
}
#[allow(clippy::read_zero_byte_vec)] // false positive
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
dispatcher: rpc::RpcDispatcher<MsgPackSerializer, C>,
read: impl AsyncRead + Unpin,
@ -46,34 +46,45 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
let mut read = BufReader::new(read);
let mut decode_buf = vec![];
let mut decoder = U32PrefixedCodec {};
let mut decoder_buf = bytes::BytesMut::new();
let shutdown_fut = shutdown_rx.wait();
pin!(shutdown_fut);
loop {
tokio::select! {
u = read.read_u32() => {
let msg_length = u? as usize;
decode_buf.resize(msg_length, 0);
tokio::select! {
r = read.read_exact(&mut decode_buf) => match dispatcher.dispatch(&decode_buf[..r?]) {
r = read.read_buf(&mut decoder_buf) => {
// A zero-length read means the remote closed the connection.
if r? == 0 {
return Ok(None);
}
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
match dispatcher.dispatch(&frame) {
MaybeSync::Sync(Some(v)) => {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
},
MaybeSync::Sync(None) => continue,
MaybeSync::Future(fut) => {
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
}
});
}
},
r = &mut shutdown_fut => return Ok(r.ok()),
MaybeSync::Stream((stream, fut)) => {
if let Some(stream) = stream {
dispatcher.register_stream(write_tx.clone(), stream).await;
}
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
let _ = write_tx.send(v).await;
}
});
}
}
};
},
Some(m) = write_rx.recv() => {
@ -88,3 +99,33 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
write.flush().await?;
}
}
/// Codec that decodes length-prefixed msgpack messages in a cancellation-safe
/// way using Tokio's codecs.
pub struct U32PrefixedCodec {}
const U32_SIZE: usize = 4;
impl tokio_util::codec::Decoder for U32PrefixedCodec {
type Item = Vec<u8>;
type Error = io::Error;
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < U32_SIZE {
src.reserve(U32_SIZE - src.len());
return Ok(None);
}
let mut be_bytes = [0; U32_SIZE];
be_bytes.copy_from_slice(&src[..U32_SIZE]);
let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize);
if src.len() < required_len {
src.reserve(required_len - src.len());
return Ok(None);
}
// Consume exactly one frame; any bytes after it belong to the next message.
let frame = src.split_to(required_len);
Ok(Some(frame[U32_SIZE..].to_vec()))
}
}
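
A quick test-style sketch (not part of the commit) of the framing behaviour: the codec buffers partial reads and only yields a frame once the whole length-prefixed payload has arrived.

#[test]
fn decodes_across_partial_reads() {
    use bytes::{BufMut, BytesMut};
    use tokio_util::codec::Decoder;

    let mut codec = U32PrefixedCodec {};
    let mut buf = BytesMut::new();

    buf.put_u32(5); // big-endian length prefix
    buf.put_slice(b"he"); // only part of the payload has arrived
    assert!(codec.decode(&mut buf).unwrap().is_none());

    buf.put_slice(b"llo"); // the rest arrives on a later read
    assert_eq!(codec.decode(&mut buf).unwrap(), Some(b"hello".to_vec()));
}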


@ -15,17 +15,26 @@ use std::{
use crate::log;
use futures::{future::BoxFuture, Future, FutureExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
sync::{mpsc, oneshot},
};
use crate::util::errors::AnyError;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>;
pub type AsyncMethod =
Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>;
pub type Duplex = Arc<
dyn Send
+ Sync
+ Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>),
>;
pub enum Method {
Sync(SyncMethod),
Async(AsyncMethod),
Duplex(Duplex),
}
/// Serialization is given to the RpcBuilder and defines how data gets serialized
@ -81,6 +90,12 @@ pub struct RpcMethodBuilder<S, C> {
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}
#[derive(Serialize)]
struct DuplexStreamStarted {
pub for_request_id: u32,
pub stream_id: u32,
}
impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
/// Registers a synchronous rpc call that returns its result directly.
pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
@ -179,14 +194,105 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
);
}
/// Registers an async rpc call whose handler is passed a duplex stream; the
/// other half of the stream is relayed to the calling client as stream messages.
pub fn register_duplex<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
where
P: DeserializeOwned + Send + 'static,
R: Serialize + Send + Sync + 'static,
Fut: Future<Output = Result<R, AnyError>> + Send,
F: (Fn(DuplexStream, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
{
let serial = self.serializer.clone();
let context = self.context.clone();
self.methods.insert(
method_name,
Method::Duplex(Arc::new(move |id, body| {
let param = match serial.deserialize::<RequestParams<P>>(body) {
Ok(p) => p,
Err(err) => {
return (
None,
future::ready(id.map(|id| {
serial.serialize(&ErrorResponse {
id,
error: ResponseError {
code: 0,
message: format!("{:?}", err),
},
})
}))
.boxed(),
);
}
};
let callback = callback.clone();
let serial = serial.clone();
let context = context.clone();
let stream_id = next_message_id();
let (client, server) = tokio::io::duplex(8192);
let fut = async move {
match callback(server, param.params, context).await {
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
Err(err) => id.map(|id| {
serial.serialize(&ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{:?}", err),
},
})
}),
}
};
(
Some(StreamDto {
req_id: id.unwrap_or(0),
stream_id,
duplex: client,
}),
fut.boxed(),
)
})),
);
}
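
As an aside, a hedged sketch of what registering a duplex method looks like from a caller's perspective (the "echo" method, the unit params type, and the register_echo wrapper are illustrative; the real registration added by this commit is the "spawn" handler in the tunnel server code further down):

use tokio::io::{AsyncReadExt, AsyncWriteExt};

// Sketch: registers an illustrative "echo" duplex method on a method builder.
// The handler gets one half of a tokio duplex; the other half is pumped to the
// peer as stream_started / stream_data / stream_ended messages.
fn register_echo<S: Serialization, C: Send + Sync + 'static>(methods: &mut RpcMethodBuilder<S, C>) {
    methods.register_duplex("echo", |mut stream, _params: (), _ctx| async move {
        let mut buf = vec![0u8; 4096];
        loop {
            // Copy whatever the peer writes on the stream straight back to it.
            match stream.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => {
                    if stream.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
            }
        }
        Ok(())
    });
}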
/// Builds into a usable, sync rpc dispatcher.
pub fn build(self, log: log::Logger) -> RpcDispatcher<S, C> {
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let s1 = streams.clone();
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
let s1 = s1.clone();
async move {
s1.lock().await.remove(&m.stream);
Ok(())
}
});
let s2 = streams.clone();
self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
let s2 = s2.clone();
async move {
let mut lock = s2.lock().await;
if let Some(stream) = lock.get_mut(&m.stream) {
let _ = stream.write_all(&m.segment).await;
}
Ok(())
}
});
RpcDispatcher {
log,
context: self.context,
calls: self.calls,
serializer: self.serializer,
methods: Arc::new(self.methods),
streams,
}
}
}
@ -281,6 +387,7 @@ pub struct RpcDispatcher<S, C> {
serializer: Arc<S>,
methods: Arc<HashMap<&'static str, Method>>,
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
}
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
@ -310,6 +417,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
match method {
Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
None => MaybeSync::Sync(id.map(|id| {
self.serializer.serialize(&ErrorResponse {
id,
@ -333,11 +441,91 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
}
}
/// Registers a stream call returned from dispatch().
pub async fn register_stream(
&self,
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
dto: StreamDto,
) {
let stream_id = dto.stream_id;
let for_request_id = dto.req_id;
let (mut read, write) = tokio::io::split(dto.duplex);
let serial = self.serializer.clone();
self.streams.lock().await.insert(dto.stream_id, write);
tokio::spawn(async move {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_STARTED,
params: DuplexStreamStarted {
stream_id,
for_request_id,
},
})
.into(),
)
.await;
if r.is_err() {
return;
}
let mut buf = Vec::with_capacity(4096);
loop {
match read.read_buf(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_DATA,
params: StreamDataParams {
segment: &buf[..n],
stream: stream_id,
},
})
.into(),
)
.await;
if r.is_err() {
return;
}
buf.truncate(0);
}
}
}
let _ = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_ENDED,
params: StreamEndedParams { stream: stream_id },
})
.into(),
)
.await;
});
}
pub fn context(&self) -> Arc<C> {
self.context.clone()
}
}
const METHOD_STREAM_STARTED: &str = "stream_started";
const METHOD_STREAM_DATA: &str = "stream_data";
const METHOD_STREAM_ENDED: &str = "stream_ended";
trait AssertIsSync: Sync {}
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
@ -349,6 +537,25 @@ struct PartialIncoming {
pub error: Option<ResponseError>,
}
#[derive(Deserialize)]
struct StreamDataIncomingParams {
#[serde(with = "serde_bytes")]
pub segment: Vec<u8>,
pub stream: u32,
}
#[derive(Serialize, Deserialize)]
struct StreamDataParams<'a> {
#[serde(with = "serde_bytes")]
pub segment: &'a [u8],
pub stream: u32,
}
#[derive(Serialize, Deserialize)]
struct StreamEndedParams {
pub stream: u32,
}
#[derive(Serialize)]
pub struct FullRequest<M: AsRef<str>, P> {
pub id: Option<u32>,
@ -384,7 +591,14 @@ enum Outcome {
Error(ResponseError),
}
pub struct StreamDto {
stream_id: u32,
req_id: u32,
duplex: DuplexStream,
}
pub enum MaybeSync {
Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
Future(BoxFuture<'static, Option<Vec<u8>>>),
Sync(Option<Vec<u8>>),
}
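
To make the stream protocol concrete: a registered stream produces three ordinary requests sent with id: None, built from the types above. A sketch with made-up ids and payload (not taken from the diff):

// Sketch only: the control messages a registered stream produces on the wire.
fn example_stream_lifecycle() {
    // Announce the stream, tying it to the request that opened it.
    let started = FullRequest {
        id: None,
        method: METHOD_STREAM_STARTED,
        params: DuplexStreamStarted { for_request_id: 1, stream_id: 2 },
    };
    // Each chunk read from the duplex is forwarded as stream_data...
    let data = FullRequest {
        id: None,
        method: METHOD_STREAM_DATA,
        params: StreamDataParams { segment: b"hello", stream: 2 },
    };
    // ...and stream_ended tears the stream down once EOF is reached.
    let ended = FullRequest {
        id: None,
        method: METHOD_STREAM_ENDED,
        params: StreamEndedParams { stream: 2 },
    };
    let _ = (started, data, ended);
}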


@ -86,8 +86,8 @@ impl<'a> SelfUpdate<'a> {
// Try to rename the old CLI to the tempdir, where it can get cleaned up by the
// OS later. However, this can fail if the tempdir is on a different drive
// than the installation dir. In this case just rename it to ".old".
if fs::rename(&target_path, &tempdir.path().join("old-code-cli")).is_err() {
fs::rename(&target_path, &target_path.with_extension(".old"))
if fs::rename(&target_path, tempdir.path().join("old-code-cli")).is_err() {
fs::rename(&target_path, target_path.with_extension(".old"))
.map_err(|e| wrap(e, "failed to rename old CLI"))?;
}
@ -132,7 +132,7 @@ fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Res
let archive_file = unzipped_files[0]
.as_ref()
.map_err(|e| wrap(e, "error listing update files"))?;
fs::copy(&archive_file.path(), staging_path)
fs::copy(archive_file.path(), staging_path)
.map_err(|e| wrap(e, "error copying to staging file"))?;
Ok(())
}
@ -140,7 +140,7 @@ fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Res
#[cfg(target_os = "windows")]
fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> {
let permissions = from.metadata()?.permissions();
fs::set_permissions(&to, permissions)?;
fs::set_permissions(to, permissions)?;
Ok(())
}


@ -16,7 +16,7 @@ use crate::util::command::{capture_command, kill_tree};
use crate::util::errors::{
wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError,
};
use crate::util::http::{self, SimpleHttp};
use crate::util::http::{self, BoxedHttp};
use crate::util::io::SilentCopyProgress;
use crate::util::machine::process_exists;
use crate::{debug, info, log, span, spanf, trace, warning};
@ -176,7 +176,7 @@ impl ServerParamsRaw {
pub async fn resolve(
self,
log: &log::Logger,
http: impl SimpleHttp + Send + Sync + 'static,
http: BoxedHttp,
) -> Result<ResolvedServerParams, AnyError> {
Ok(ResolvedServerParams {
release: self.get_or_fetch_commit_id(log, http).await?,
@ -187,7 +187,7 @@ impl ServerParamsRaw {
async fn get_or_fetch_commit_id(
&self,
log: &log::Logger,
http: impl SimpleHttp + Send + Sync + 'static,
http: BoxedHttp,
) -> Result<Release, AnyError> {
let target = match self.headless {
true => TargetKind::Server,
@ -287,7 +287,7 @@ async fn install_server_if_needed(
log: &log::Logger,
paths: &ServerPaths,
release: &Release,
http: impl SimpleHttp + Send + Sync + 'static,
http: BoxedHttp,
existing_archive_path: Option<PathBuf>,
) -> Result<(), AnyError> {
if paths.executable.exists() {
@ -321,7 +321,7 @@ async fn download_server(
path: &Path,
release: &Release,
log: &log::Logger,
http: impl SimpleHttp + Send + Sync + 'static,
http: BoxedHttp,
) -> Result<PathBuf, AnyError> {
let response = UpdateService::new(log.clone(), http)
.get_download_stream(release)
@ -403,20 +403,20 @@ async fn do_extension_install_on_running_server(
}
}
pub struct ServerBuilder<'a, Http: SimpleHttp + Send + Sync + Clone> {
pub struct ServerBuilder<'a> {
logger: &'a log::Logger,
server_params: &'a ResolvedServerParams,
last_used: LastUsedServers<'a>,
server_paths: ServerPaths,
http: Http,
http: BoxedHttp,
}
impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Http> {
impl<'a> ServerBuilder<'a> {
pub fn new(
logger: &'a log::Logger,
server_params: &'a ResolvedServerParams,
launcher_paths: &'a LauncherPaths,
http: Http,
http: BoxedHttp,
) -> Self {
Self {
logger,


@ -5,6 +5,7 @@
use crate::async_pipe::get_socket_rw_stream;
use crate::constants::CONTROL_PORT;
use crate::log;
use crate::msgpack_rpc::U32PrefixedCodec;
use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization};
use crate::self_update::SelfUpdate;
use crate::state::LauncherPaths;
@ -12,7 +13,8 @@ use crate::tunnels::protocol::HttpRequestParams;
use crate::tunnels::socket_signal::CloseReason;
use crate::update_service::{Platform, UpdateService};
use crate::util::errors::{
wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError,
wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError,
NoAttachedServerError,
};
use crate::util::http::{
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
@ -24,11 +26,14 @@ use crate::util::sync::{new_barrier, Barrier};
use opentelemetry::trace::SpanKind;
use opentelemetry::KeyValue;
use std::collections::HashMap;
use std::process::Stdio;
use tokio::pin;
use tokio_util::codec::Decoder;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream};
use tokio::sync::{mpsc, Mutex};
use super::code_server::{
@ -40,8 +45,8 @@ use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, ForwardParams,
ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog,
ServerMessageParams, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult,
VersionParams,
ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams,
UpdateResult, VersionParams,
};
use super::server_bridge::ServerBridge;
use super::server_multiplexer::ServerMultiplexer;
@ -73,7 +78,7 @@ struct HandlerContext {
/// install platform for the VS Code server
platform: Platform,
/// http client to make download/update requests
http: FallbackSimpleHttp,
http: Arc<FallbackSimpleHttp>,
/// requests being served by the client
http_requests: HttpRequestsMap,
}
@ -196,7 +201,7 @@ pub async fn serve(
],
);
cx.span().end();
});
});
}
}
}
@ -247,7 +252,10 @@ async fn process_socket(
server_bridges: server_bridges.clone(),
port_forwarding,
platform,
http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated),
http: Arc::new(FallbackSimpleHttp::new(
ReqwestSimpleHttp::new(),
http_delegated,
)),
http_requests: http_requests.clone(),
});
@ -276,6 +284,9 @@ async fn process_socket(
rpc.register_async("unforward", |p: UnforwardParams, c| async move {
handle_unforward(&c.log, &c.port_forwarding, p).await
});
rpc.register_duplex("spawn", |stream, p: SpawnParams, c| async move {
handle_spawn(&c.log, stream, p).await
});
rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| {
if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) {
req.initial_response(p.status_code, p.headers);
@ -393,20 +404,20 @@ async fn handle_socket_read(
rx_counter: Arc<AtomicUsize>,
rpc: &RpcDispatcher<MsgPackSerializer, HandlerContext>,
) -> Result<(), std::io::Error> {
let mut socket_reader = BufReader::new(readhalf);
let mut decode_buf = vec![];
let mut readhalf = BufReader::new(readhalf);
let mut decoder = U32PrefixedCodec {};
let mut decoder_buf = bytes::BytesMut::new();
loop {
let read = read_next(
&mut socket_reader,
&rx_counter,
&mut closer,
&mut decode_buf,
)
.await;
let read_len = tokio::select! {
r = readhalf.read_buf(&mut decoder_buf) => r,
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
}?;
match read {
Ok(len) => match rpc.dispatch(&decode_buf[..len]) {
rx_counter.fetch_add(read_len, Ordering::Relaxed);
// A zero-length read means the socket was closed; stop instead of spinning.
if read_len == 0 {
return Ok(());
}
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
match rpc.dispatch(&frame) {
MaybeSync::Sync(Some(v)) => {
if socket_tx.send(SocketSignal::Send(v)).await.is_err() {
return Ok(());
@ -421,34 +432,22 @@ async fn handle_socket_read(
}
});
}
},
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
Err(e) => return Err(e),
MaybeSync::Stream((stream, fut)) => {
if let Some(stream) = stream {
rpc.register_stream(socket_tx.clone(), stream).await;
}
let socket_tx = socket_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
socket_tx.send(SocketSignal::Send(v)).await.ok();
}
});
}
}
}
}
}
/// Reads and handles the next data packet. Returns the next packet to dispatch,
/// or an error (including EOF).
async fn read_next(
socket_reader: &mut BufReader<impl AsyncRead + Unpin>,
rx_counter: &Arc<AtomicUsize>,
closer: &mut Barrier<()>,
decode_buf: &mut Vec<u8>,
) -> Result<usize, std::io::Error> {
let msg_length = tokio::select! {
u = socket_reader.read_u32() => u? as usize,
_ = closer.wait() => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
};
decode_buf.resize(msg_length, 0);
rx_counter.fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed);
tokio::select! {
r = socket_reader.read_exact(decode_buf) => r,
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
}
}
#[derive(Clone)]
struct ServerOutputSink {
tx: mpsc::Sender<SocketSignal>,
@ -487,7 +486,9 @@ async fn handle_serve(
};
let resolved = if params.use_local_download {
params_raw.resolve(&c.log, c.http.delegated()).await
params_raw
.resolve(&c.log, Arc::new(c.http.delegated()))
.await
} else {
params_raw.resolve(&c.log, c.http.clone()).await
}?;
@ -518,7 +519,7 @@ async fn handle_serve(
&install_log,
&resolved,
&c.launcher_paths,
c.http.delegated(),
Arc::new(c.http.delegated()),
);
do_setup!(sb)
} else {
@ -606,7 +607,7 @@ fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {
}
async fn handle_update(
http: &FallbackSimpleHttp,
http: &Arc<FallbackSimpleHttp>,
log: &log::Logger,
did_update: &AtomicBool,
params: &UpdateParams,
@ -732,3 +733,83 @@ async fn handle_call_server_http(
.to_vec(),
})
}
async fn handle_spawn(
log: &log::Logger,
mut duplex: DuplexStream,
params: SpawnParams,
) -> Result<SpawnResult, AnyError> {
debug!(
log,
"requested to spawn {} with args {:?}", params.command, params.args
);
let mut p = tokio::process::Command::new(&params.command)
.args(&params.args)
.envs(&params.env)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(CodeError::ProcessSpawnFailed)?;
let mut stdout = p.stdout.take().unwrap();
let mut stderr = p.stderr.take().unwrap();
let mut stdin = p.stdin.take().unwrap();
let (tx, mut rx) = mpsc::channel(4);
macro_rules! copy_stream_to {
($target:expr) => {
let tx = tx.clone();
tokio::spawn(async move {
let mut buf = vec![0; 4096];
loop {
let n = match $target.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
if tx.send(buf[..n].to_vec()).await.is_err() {
return;
}
}
});
};
}
copy_stream_to!(stdout);
copy_stream_to!(stderr);
let mut stdin_buf = vec![0; 4096];
let closed = p.wait();
pin!(closed);
loop {
tokio::select! {
Ok(n) = duplex.read(&mut stdin_buf) => {
let _ = stdin.write_all(&stdin_buf[..n]).await;
},
Some(m) = rx.recv() => {
let _ = duplex.write_all(&m).await;
},
r = &mut closed => {
let r = match r {
Ok(e) => SpawnResult {
message: e.to_string(),
exit_code: e.code().unwrap_or(-1),
},
Err(e) => SpawnResult {
message: e.to_string(),
exit_code: -1,
},
};
debug!(
log,
"spawned command {} exited with code {}", params.command, r.exit_code
);
return Ok(r)
},
}
}
}


@ -68,7 +68,7 @@ impl ServerPaths {
// VS Code Server pid
pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> {
write(&self.pidfile, &format!("{}", pid)).map_err(|e| {
write(&self.pidfile, format!("{}", pid)).map_err(|e| {
wrap(
e,
format!("error writing process id into {}", self.pidfile.display()),


@ -158,6 +158,20 @@ impl Default for VersionParams {
}
}
#[derive(Deserialize)]
pub struct SpawnParams {
pub command: String,
pub args: Vec<String>,
#[serde(default)]
pub env: HashMap<String, String>,
}
#[derive(Serialize)]
pub struct SpawnResult {
pub message: String,
pub exit_code: i32,
}
pub mod singleton {
use crate::log;
use serde::{Deserialize, Serialize};


@ -59,7 +59,7 @@ impl CliServiceManager for WindowsService {
};
for arg in args {
add_arg(*arg);
add_arg(arg);
}
add_arg("--log-to-file");


@ -22,6 +22,12 @@ pub enum SocketSignal {
CloseWith(CloseReason),
}
impl From<Vec<u8>> for SocketSignal {
fn from(v: Vec<u8>) -> Self {
SocketSignal::Send(v)
}
}
impl SocketSignal {
pub fn from_message<T>(msg: &T) -> Self
where


@ -3,6 +3,8 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::{
@ -139,7 +141,12 @@ async fn handle_serve(
},
};
let sb = ServerBuilder::new(&c.log, &resolved, &c.launcher_paths, c.http.clone());
let sb = ServerBuilder::new(
&c.log,
&resolved,
&c.launcher_paths,
Arc::new(c.http.clone()),
);
let code_server = match sb.get_running().await? {
Some(AnyCodeServer::Socket(s)) => s,
Some(_) => return Err(MismatchedLaunchModeError().into()),


@ -3,7 +3,7 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::path::Path;
use std::{fmt, path::Path};
use serde::Deserialize;
@ -11,19 +11,20 @@ use crate::{
constants::VSCODE_CLI_UPDATE_ENDPOINT,
debug, log, options, spanf,
util::{
errors::{AnyError, UnsupportedPlatformError, UpdatesNotConfigured, WrappedError},
http::{SimpleHttp, SimpleResponse},
errors::{AnyError, CodeError, UpdatesNotConfigured, WrappedError},
http::{BoxedHttp, SimpleResponse},
io::ReportCopyProgress,
},
};
/// Implementation of the VS Code Update service for use in the CLI.
pub struct UpdateService {
client: Box<dyn SimpleHttp + Send + Sync + 'static>,
client: BoxedHttp,
log: log::Logger,
}
/// Describes a specific release, can be created manually or returned from the update service.
#[derive(Clone, Eq, PartialEq)]
pub struct Release {
pub name: String,
pub platform: Platform,
@ -53,11 +54,8 @@ fn quality_download_segment(quality: options::Quality) -> &'static str {
}
impl UpdateService {
pub fn new(log: log::Logger, http: impl SimpleHttp + Send + Sync + 'static) -> Self {
UpdateService {
client: Box::new(http),
log,
}
pub fn new(log: log::Logger, http: BoxedHttp) -> Self {
UpdateService { client: http, log }
}
pub async fn get_release_by_semver_version(
@ -71,7 +69,7 @@ impl UpdateService {
VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
let download_segment = target
.download_segment(platform)
.ok_or(UnsupportedPlatformError())?;
.ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?;
let download_url = format!(
"{}/api/versions/{}/{}/{}",
update_endpoint,
@ -113,7 +111,7 @@ impl UpdateService {
VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
let download_segment = target
.download_segment(platform)
.ok_or(UnsupportedPlatformError())?;
.ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?;
let download_url = format!(
"{}/api/latest/{}/{}",
update_endpoint,
@ -150,7 +148,7 @@ impl UpdateService {
let download_segment = release
.target
.download_segment(release.platform)
.ok_or(UnsupportedPlatformError())?;
.ok_or_else(|| CodeError::UnsupportedPlatform(release.platform.to_string()))?;
let download_url = format!(
"{}/commit:{}/{}/{}",
@ -208,7 +206,7 @@ impl TargetKind {
}
}
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Platform {
LinuxAlpineX64,
LinuxAlpineARM64,
@ -306,3 +304,20 @@ impl Platform {
}
}
}
impl fmt::Display for Platform {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(match self {
Platform::LinuxAlpineARM64 => "LinuxAlpineARM64",
Platform::LinuxAlpineX64 => "LinuxAlpineX64",
Platform::LinuxX64 => "LinuxX64",
Platform::LinuxARM64 => "LinuxARM64",
Platform::LinuxARM32 => "LinuxARM32",
Platform::DarwinX64 => "DarwinX64",
Platform::DarwinARM64 => "DarwinARM64",
Platform::WindowsX64 => "WindowsX64",
Platform::WindowsX86 => "WindowsX86",
Platform::WindowsARM64 => "WindowsARM64",
})
}
}


@ -2,29 +2,47 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use super::errors::{wrap, AnyError, CommandFailed, WrappedError};
use std::{borrow::Cow, ffi::OsStr, process::Stdio};
use super::errors::CodeError;
use std::{
borrow::Cow,
ffi::OsStr,
process::{Output, Stdio},
};
use tokio::process::Command;
pub async fn capture_command_and_check_status(
command_str: impl AsRef<OsStr>,
args: &[impl AsRef<OsStr>],
) -> Result<std::process::Output, AnyError> {
) -> Result<std::process::Output, CodeError> {
let output = capture_command(&command_str, args).await?;
check_output_status(output, || {
format!(
"{} {}",
command_str.as_ref().to_string_lossy(),
args.iter()
.map(|a| a.as_ref().to_string_lossy())
.collect::<Vec<Cow<'_, str>>>()
.join(" ")
)
})
}
pub fn check_output_status(
output: Output,
cmd_str: impl FnOnce() -> String,
) -> Result<std::process::Output, CodeError> {
if !output.status.success() {
return Err(CommandFailed {
command: format!(
"{} {}",
command_str.as_ref().to_string_lossy(),
args.iter()
.map(|a| a.as_ref().to_string_lossy())
.collect::<Vec<Cow<'_, str>>>()
.join(" ")
),
output,
}
.into());
return Err(CodeError::CommandFailed {
command: cmd_str(),
code: output.status.code().unwrap_or(-1),
output: String::from_utf8_lossy(if output.stderr.is_empty() {
&output.stdout
} else {
&output.stderr
})
.into(),
});
}
Ok(output)
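
For illustration, call sites of the checked helper now get a structured CodeError rather than a WrappedError; a hypothetical example (the command, arguments, and helper name are made up):

// Sketch: run a command and surface CodeError::CommandFailed on a non-zero exit.
async fn git_version() -> Result<String, CodeError> {
    let output = capture_command_and_check_status("git", &["--version"]).await?;
    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
}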
@ -33,7 +51,7 @@ pub async fn capture_command_and_check_status(
pub async fn capture_command<A, I, S>(
command_str: A,
args: I,
) -> Result<std::process::Output, WrappedError>
) -> Result<std::process::Output, CodeError>
where
A: AsRef<OsStr>,
I: IntoIterator<Item = S>,
@ -45,27 +63,23 @@ where
.stdout(Stdio::piped())
.output()
.await
.map_err(|e| {
wrap(
e,
format!(
"failed to execute command '{}'",
command_str.as_ref().to_string_lossy()
),
)
.map_err(|e| CodeError::CommandFailed {
command: command_str.as_ref().to_string_lossy().to_string(),
code: -1,
output: e.to_string(),
})
}
/// Kills a process and all of its children.
#[cfg(target_os = "windows")]
pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> {
capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?;
Ok(())
}
/// Kills a process and all of its children.
#[cfg(not(target_os = "windows"))]
pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> {
use futures::future::join_all;
use tokio::io::{AsyncBufReadExt, BufReader};
@ -82,7 +96,11 @@ pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
.stdin(Stdio::null())
.stdout(Stdio::piped())
.spawn()
.map_err(|e| wrap(e, "error enumerating process tree"))?;
.map_err(|e| CodeError::CommandFailed {
command: format!("pgrep -P {}", parent_id),
code: -1,
output: e.to_string(),
})?;
let mut kill_futures = vec![tokio::spawn(
async move { kill_single_pid(parent_id).await },


@ -258,18 +258,6 @@ impl std::fmt::Display for RefreshTokenNotAvailableError {
}
}
#[derive(Debug)]
pub struct UnsupportedPlatformError();
impl std::fmt::Display for UnsupportedPlatformError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"This operation is not supported on your current platform"
)
}
}
#[derive(Debug)]
pub struct NoInstallInUserProvidedPath(pub String);
@ -419,28 +407,6 @@ impl std::fmt::Display for OAuthError {
}
}
#[derive(Debug)]
pub struct CommandFailed {
pub output: std::process::Output,
pub command: String,
}
impl std::fmt::Display for CommandFailed {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"Failed to run command \"{}\" (code {}): {}",
self.command,
self.output.status,
String::from_utf8_lossy(if self.output.stderr.is_empty() {
&self.output.stdout
} else {
&self.output.stderr
})
)
}
}
// Makes an "AnyError" enum that contains any of the given errors, in the form
// `enum AnyError { FooError(FooError) }` (when given `makeAnyError!(FooError)`).
// Useful to easily deal with application error types without making tons of "From"
@ -500,6 +466,20 @@ pub enum CodeError {
#[cfg(windows)]
#[error("could not get windows app lock: {0:?}")]
AppLockFailed(std::io::Error),
#[error("failed to run command \"{command}\" (code {code}): {output}")]
CommandFailed {
command: String,
code: i32,
output: String,
},
#[error("platform not currently supported: {0}")]
UnsupportedPlatform(String),
#[error("This machine not meet {name}'s prerequisites, expected either...: {bullets}")]
PrerequisitesFailed { name: &'static str, bullets: String },
#[error("failed to spawn process: {0:?}")]
ProcessSpawnFailed(std::io::Error)
}
makeAnyError!(
@ -518,7 +498,6 @@ makeAnyError!(
ExtensionInstallFailed,
MismatchedLaunchModeError,
NoAttachedServerError,
UnsupportedPlatformError,
RefreshTokenNotAvailableError,
NoInstallInUserProvidedPath,
UserCancelledInstallation,
@ -530,7 +509,6 @@ makeAnyError!(
UpdatesNotConfigured,
CorruptDownload,
MissingHomeDirectory,
CommandFailed,
OAuthError,
InvalidRpcDataError,
CodeError


@ -16,7 +16,7 @@ use hyper::{
HeaderMap, StatusCode,
};
use serde::de::DeserializeOwned;
use std::{io, pin::Pin, str::FromStr, task::Poll};
use std::{io, pin::Pin, str::FromStr, sync::Arc, task::Poll};
use tokio::{
fs,
io::{AsyncRead, AsyncReadExt},
@ -116,6 +116,8 @@ pub trait SimpleHttp {
) -> Result<SimpleResponse, AnyError>;
}
pub type BoxedHttp = Arc<dyn SimpleHttp + Send + Sync + 'static>;
// Implementation of SimpleHttp that uses a reqwest client.
#[derive(Clone)]
pub struct ReqwestSimpleHttp {
@ -324,7 +326,6 @@ impl AsyncRead for DelegatedReader {
/// Simple http implementation that falls back to delegated http if
/// making a direct reqwest fails.
#[derive(Clone)]
pub struct FallbackSimpleHttp {
native: ReqwestSimpleHttp,
delegated: DelegatedSimpleHttp,


@ -7,13 +7,12 @@ use std::cmp::Ordering;
use super::command::capture_command;
use crate::constants::QUALITYLESS_SERVER_NAME;
use crate::update_service::Platform;
use crate::util::errors::SetupError;
use lazy_static::lazy_static;
use regex::bytes::Regex as BinRegex;
use regex::Regex;
use tokio::fs;
use super::errors::AnyError;
use super::errors::CodeError;
lazy_static! {
static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap();
@ -41,19 +40,18 @@ impl PreReqChecker {
}
#[cfg(not(target_os = "linux"))]
pub async fn verify(&self) -> Result<Platform, AnyError> {
use crate::constants::QUALITYLESS_PRODUCT_NAME;
pub async fn verify(&self) -> Result<Platform, CodeError> {
Platform::env_default().ok_or_else(|| {
SetupError(format!(
"{} is not supported on this platform",
QUALITYLESS_PRODUCT_NAME
CodeError::UnsupportedPlatform(format!(
"{} {}",
std::env::consts::OS,
std::env::consts::ARCH
))
.into()
})
}
#[cfg(target_os = "linux")]
pub async fn verify(&self) -> Result<Platform, AnyError> {
pub async fn verify(&self) -> Result<Platform, CodeError> {
let (is_nixos, gnu_a, gnu_b, or_musl) = tokio::join!(
check_is_nixos(),
check_glibc_version(),
@ -96,10 +94,10 @@ impl PreReqChecker {
.collect::<Vec<String>>()
.join("\n");
Err(AnyError::from(SetupError(format!(
"This machine not meet {}'s prerequisites, expected either...\n{}",
QUALITYLESS_SERVER_NAME, bullets,
))))
Err(CodeError::PrerequisitesFailed {
bullets,
name: QUALITYLESS_SERVER_NAME,
})
}
}


@ -4,9 +4,11 @@
*--------------------------------------------------------------------------------------------*/
use async_trait::async_trait;
use std::{marker::PhantomData, sync::Arc};
use tokio::sync::{
broadcast, mpsc,
watch::{self, error::RecvError},
use tokio::{
sync::{
broadcast, mpsc,
watch::{self, error::RecvError},
},
};
#[derive(Clone)]