diff --git a/cli/Cargo.lock b/cli/Cargo.lock index a4b541970d0..bd207ab1373 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -373,30 +373,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-deque" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" -dependencies = [ - "cfg-if", - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f916dfc5d356b0ed9dae65f1db9fc9770aa2851d2662b988ccf4fe3516e86348" -dependencies = [ - "autocfg", - "cfg-if", - "crossbeam-utils", - "memoffset", - "scopeguard", -] - [[package]] name = "crossbeam-utils" version = "0.8.12" @@ -529,12 +505,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "either" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797" - [[package]] name = "encode_unicode" version = "0.3.6" @@ -1729,30 +1699,6 @@ dependencies = [ "rand_core 0.5.1", ] -[[package]] -name = "rayon" -version = "1.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd99e5772ead8baa5215278c9b15bf92087709e9c1b2d1f97cdb5a183c933a7d" -dependencies = [ - "autocfg", - "crossbeam-deque", - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "258bcdb5ac6dad48491bb2992db6b7cf74878b0384908af124823d118c99683f" -dependencies = [ - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-utils", - "num_cpus", -] - [[package]] name = "redox_syscall" version = "0.2.16" @@ -2183,7 +2129,6 @@ dependencies = [ "libc", "ntapi", "once_cell", - "rayon", "winapi", ] diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 6b5c8d07c3f..ac05391f654 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -22,7 +22,7 @@ flate2 = { version = "1.0.22" } zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] } regex = { version = "1.5.5" } lazy_static = { version = "1.4.0" } -sysinfo = { version = "0.27.7" } +sysinfo = { version = "0.27.7", default-features = false } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } rmp-serde = "1.0" diff --git a/cli/src/bin/code/main.rs b/cli/src/bin/code/main.rs index 54660c935f6..2dd31707466 100644 --- a/cli/src/bin/code/main.rs +++ b/cli/src/bin/code/main.rs @@ -8,7 +8,7 @@ use std::process::Command; use clap::Parser; use cli::{ - commands::{args, internal_wsl, tunnels, update, version, CommandContext}, + commands::{args, tunnels, update, version, CommandContext}, constants::get_default_user_agent, desktop, log, state::LauncherPaths, @@ -65,9 +65,6 @@ async fn main() -> Result<(), std::convert::Infallible> { .. }) => match cmd { args::StandaloneCommands::Update(args) => update::update(context!(), args).await, - args::StandaloneCommands::Wsl(args) => match args.command { - args::WslCommands::Serve => internal_wsl::serve(context!()).await, - }, }, args::AnyCli::Standalone(args::StandaloneCli { core: c, .. }) | args::AnyCli::Integrated(args::IntegratedCli { core: c, .. }) => match c.subcommand { @@ -98,6 +95,8 @@ async fn main() -> Result<(), std::convert::Infallible> { args::VersionSubcommand::Show => version::show(context!()).await, }, + Some(args::Commands::CommandShell) => tunnels::command_shell(context!()).await, + Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand { Some(args::TunnelSubcommand::Prune) => tunnels::prune(context!()).await, Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context!()).await, diff --git a/cli/src/commands.rs b/cli/src/commands.rs index 082031af201..754729f2c04 100644 --- a/cli/src/commands.rs +++ b/cli/src/commands.rs @@ -6,7 +6,6 @@ mod context; pub mod args; -pub mod internal_wsl; pub mod tunnels; pub mod update; pub mod version; diff --git a/cli/src/commands/args.rs b/cli/src/commands/args.rs index 8687819f893..1cc557af3d7 100644 --- a/cli/src/commands/args.rs +++ b/cli/src/commands/args.rs @@ -146,22 +146,6 @@ impl<'a> From<&'a CliCore> for CodeServerArgs { pub enum StandaloneCommands { /// Updates the CLI. Update(StandaloneUpdateArgs), - - /// Internal commands for WSL serving. - #[clap(hide = true)] - Wsl(WslArgs), -} - -#[derive(Args, Debug, Clone)] -pub struct WslArgs { - #[clap(subcommand)] - pub command: WslCommands, -} - -#[derive(Subcommand, Debug, Clone)] -pub enum WslCommands { - /// Runs the WSL server on stdin/out - Serve, } #[derive(Args, Debug, Clone)] @@ -187,6 +171,10 @@ pub enum Commands { /// Changes the version of the editor you're using. Version(VersionArgs), + + /// Runs the control server on process stdin/stdout + #[clap(hide = true)] + CommandShell, } #[derive(Args, Debug, Clone)] diff --git a/cli/src/commands/internal_wsl.rs b/cli/src/commands/internal_wsl.rs deleted file mode 100644 index 483ee52c6aa..00000000000 --- a/cli/src/commands/internal_wsl.rs +++ /dev/null @@ -1,32 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -use crate::{ - tunnels::{serve_wsl, shutdown_signal::ShutdownRequest}, - util::{errors::AnyError, prereqs::PreReqChecker}, -}; - -use super::CommandContext; - -pub async fn serve(ctx: CommandContext) -> Result { - let signal = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]); - let platform = spanf!( - ctx.log, - ctx.log.span("prereq"), - PreReqChecker::new().verify() - )?; - - serve_wsl( - ctx.log, - ctx.paths, - (&ctx.args).into(), - platform, - ctx.http, - signal, - ) - .await?; - - Ok(0) -} diff --git a/cli/src/commands/tunnels.rs b/cli/src/commands/tunnels.rs index 56db02c58ed..cc0249224be 100644 --- a/cli/src/commands/tunnels.rs +++ b/cli/src/commands/tunnels.rs @@ -25,13 +25,13 @@ use crate::{ code_server::CodeServerArgs, create_service_manager, dev_tunnels, legal, paths::get_all_servers, - protocol, + protocol, serve_stream, shutdown_signal::ShutdownRequest, singleton_client::do_single_rpc_call, singleton_server::{ make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs, }, - Next, ServiceContainer, ServiceManager, + Next, ServeStreamParams, ServiceContainer, ServiceManager, }, util::{ app_lock::AppMutex, @@ -107,6 +107,25 @@ impl ServiceContainer for TunnelServiceContainer { } } +pub async fn command_shell(ctx: CommandContext) -> Result { + let platform = PreReqChecker::new().verify().await?; + serve_stream( + tokio::io::stdin(), + tokio::io::stderr(), + ServeStreamParams { + log: ctx.log, + launcher_paths: ctx.paths, + platform, + requires_auth: true, + exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), + code_server_args: (&ctx.args).into(), + }, + ) + .await; + + Ok(0) +} + pub async fn service( ctx: CommandContext, service_args: TunnelServiceSubCommands, diff --git a/cli/src/constants.rs b/cli/src/constants.rs index fa419e11568..2dac5d43563 100644 --- a/cli/src/constants.rs +++ b/cli/src/constants.rs @@ -18,7 +18,8 @@ pub const CONTROL_PORT: u16 = 31545; /// 2 - Addition of `serve.compressed` property to control whether servermsg's /// are compressed bidirectionally. /// 3 - The server's connection token is set to a SHA256 hash of the tunnel ID -pub const PROTOCOL_VERSION: u32 = 3; +/// 4 - The server's msgpack messages are no longer length-prefixed +pub const PROTOCOL_VERSION: u32 = 4; /// Prefix for the tunnel tag that includes the version. pub const PROTOCOL_VERSION_TAG_PREFIX: &str = "protocolv"; diff --git a/cli/src/msgpack_rpc.rs b/cli/src/msgpack_rpc.rs index 0350c1bfd64..219c923cdf2 100644 --- a/cli/src/msgpack_rpc.rs +++ b/cli/src/msgpack_rpc.rs @@ -4,8 +4,9 @@ *--------------------------------------------------------------------------------------------*/ use bytes::Buf; +use serde::de::DeserializeOwned; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, pin, sync::mpsc, }; @@ -18,7 +19,7 @@ use crate::{ sync::{Barrier, Receivable}, }, }; -use std::io; +use std::io::{self, Cursor, ErrorKind}; #[derive(Copy, Clone)] pub struct MsgPackSerializer {} @@ -35,21 +36,28 @@ impl Serialization for MsgPackSerializer { pub type MsgPackCaller = rpc::RpcCaller; -/// Creates a new RPC Builder that serializes to JSON. +/// Creates a new RPC Builder that serializes to msgpack. pub fn new_msgpack_rpc() -> rpc::RpcBuilder { rpc::RpcBuilder::new(MsgPackSerializer {}) } -pub async fn start_msgpack_rpc( - dispatcher: rpc::RpcDispatcher, - read: impl AsyncRead + Unpin, - mut write: impl AsyncWrite + Unpin, +/// Starting processing msgpack rpc over the given i/o. It's recommended that +/// the reader be passed in as a BufReader for efficiency. +pub async fn start_msgpack_rpc< + C: Send + Sync + 'static, + X: Clone, + S: Send + Sync + Serialization, + Read: AsyncRead + Unpin, + Write: AsyncWrite + Unpin, +>( + dispatcher: rpc::RpcDispatcher, + mut read: Read, + mut write: Write, mut msg_rx: impl Receivable>, - mut shutdown_rx: Barrier, -) -> io::Result> { + mut shutdown_rx: Barrier, +) -> io::Result<(Option, Read, Write)> { let (write_tx, mut write_rx) = mpsc::channel::>(8); - let mut read = BufReader::new(read); - let mut decoder = U32PrefixedCodec {}; + let mut decoder = MsgPackCodec::new(); let mut decoder_buf = bytes::BytesMut::new(); let shutdown_fut = shutdown_rx.wait(); @@ -61,7 +69,7 @@ pub async fn start_msgpack_rpc( r?; while let Some(frame) = decoder.decode(&mut decoder_buf)? { - match dispatcher.dispatch(&frame) { + match dispatcher.dispatch_with_partial(&frame.vec, frame.obj) { MaybeSync::Sync(Some(v)) => { let _ = write_tx.send(v).await; }, @@ -94,39 +102,94 @@ pub async fn start_msgpack_rpc( Some(m) = msg_rx.recv_msg() => { write.write_all(&m).await?; }, - r = &mut shutdown_fut => return Ok(r.ok()), + r = &mut shutdown_fut => return Ok((r.ok(), read, write)), } write.flush().await?; } } -/// Reader that reads length-prefixed msgpack messages in a cancellation-safe -/// way using Tokio's codecs. -pub struct U32PrefixedCodec {} +/// Reader that reads msgpack object messages in a cancellation-safe way using Tokio's codecs. +/// +/// rmp_serde does not support async reads, and does not plan to. But we know every +/// type in protocol is some kind of object, so by asking to deserialize the +/// requested object from a reader (repeatedly, if incomplete) we can +/// accomplish streaming. +pub struct MsgPackCodec { + _marker: std::marker::PhantomData, +} -const U32_SIZE: usize = 4; +impl MsgPackCodec { + pub fn new() -> Self { + Self { + _marker: std::marker::PhantomData::default(), + } + } +} -impl tokio_util::codec::Decoder for U32PrefixedCodec { - type Item = Vec; +pub struct MsgPackDecoded { + pub obj: T, + pub vec: Vec, +} + +impl tokio_util::codec::Decoder for MsgPackCodec { + type Item = MsgPackDecoded; type Error = io::Error; fn decode(&mut self, src: &mut bytes::BytesMut) -> Result, Self::Error> { - if src.len() < 4 { - src.reserve(U32_SIZE - src.len()); - return Ok(None); - } + let bytes_ref = src.as_ref(); + let mut cursor = Cursor::new(bytes_ref); - let mut be_bytes = [0; U32_SIZE]; - be_bytes.copy_from_slice(&src[..U32_SIZE]); - let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize); - if src.len() < required_len { - src.reserve(required_len - src.len()); - return Ok(None); + match rmp_serde::decode::from_read::<_, T>(&mut cursor) { + Err( + rmp_serde::decode::Error::InvalidDataRead(e) + | rmp_serde::decode::Error::InvalidMarkerRead(e), + ) if e.kind() == ErrorKind::UnexpectedEof => { + src.reserve(1024); + Ok(None) + } + Err(e) => Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + e.to_string(), + )), + Ok(obj) => { + let len = cursor.position() as usize; + let vec = src[..len].to_vec(); + src.advance(len); + Ok(Some(MsgPackDecoded { obj, vec })) + } } - - let msg = src[U32_SIZE..required_len].to_vec(); - src.advance(required_len); - Ok(Some(msg)) + } +} + +#[cfg(test)] +mod tests { + use serde::{Deserialize, Serialize}; + + use super::*; + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + pub struct Msg { + pub x: i32, + } + + #[test] + fn test_protocol() { + let mut c = MsgPackCodec::::new(); + let mut buf = bytes::BytesMut::new(); + + assert!(c.decode(&mut buf).unwrap().is_none()); + + buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 1 }).unwrap().as_slice()); + buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 2 }).unwrap().as_slice()); + + assert_eq!( + c.decode(&mut buf).unwrap().expect("expected msg1").obj, + Msg { x: 1 } + ); + assert_eq!( + c.decode(&mut buf).unwrap().expect("expected msg1").obj, + Msg { x: 2 } + ); } } diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs index 28dfc0efb47..02cbac61bee 100644 --- a/cli/src/rpc.rs +++ b/cli/src/rpc.rs @@ -104,6 +104,10 @@ impl RpcMethodBuilder { R: Serialize, F: Fn(P, &C) -> Result + Send + Sync + 'static, { + if self.methods.contains_key(method_name) { + panic!("Method already registered: {}", method_name); + } + let serial = self.serializer.clone(); let context = self.context.clone(); self.methods.insert( @@ -276,7 +280,9 @@ impl RpcMethodBuilder { self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| { let s1 = s1.clone(); async move { - s1.lock().await.remove(&m.stream); + if let Some(mut s) = s1.lock().await.remove(&m.stream) { + let _ = s.shutdown().await; + } Ok(()) } }); @@ -410,13 +416,17 @@ impl RpcDispatcher { /// The future or return result will be optional bytes that should be sent /// back to the socket. pub fn dispatch(&self, body: &[u8]) -> MaybeSync { - let partial = match self.serializer.deserialize::(body) { - Ok(b) => b, + match self.serializer.deserialize::(body) { + Ok(partial) => self.dispatch_with_partial(body, partial), Err(_err) => { warning!(self.log, "Failed to deserialize request, hex: {:X?}", body); - return MaybeSync::Sync(None); + MaybeSync::Sync(None) } - }; + } + } + + /// Like dispatch, but allows passing an existing PartialIncoming. + pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync { let id = partial.id; if let Some(method_name) = partial.method { @@ -536,8 +546,8 @@ trait AssertIsSync: Sync {} impl AssertIsSync for RpcDispatcher {} /// Approximate shape that is used to determine what kind of data is incoming. -#[derive(Deserialize)] -struct PartialIncoming { +#[derive(Deserialize, Debug)] +pub struct PartialIncoming { pub id: Option, pub method: Option, pub error: Option, diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index ebab9475988..801b6545e51 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -7,11 +7,12 @@ pub mod code_server; pub mod dev_tunnels; pub mod legal; pub mod paths; +pub mod protocol; pub mod shutdown_signal; pub mod singleton_client; pub mod singleton_server; -pub mod protocol; +mod challenge; mod control_server; mod nosleep; #[cfg(target_os = "linux")] @@ -31,11 +32,9 @@ mod service_macos; #[cfg(target_os = "windows")] mod service_windows; mod socket_signal; -mod wsl_server; -pub use control_server::{serve, Next}; +pub use control_server::{serve, serve_stream, Next, ServeStreamParams}; pub use nosleep::SleepInhibitor; pub use service::{ create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME, }; -pub use wsl_server::serve_wsl; diff --git a/cli/src/tunnels/challenge.rs b/cli/src/tunnels/challenge.rs new file mode 100644 index 00000000000..1c4abc651ba --- /dev/null +++ b/cli/src/tunnels/challenge.rs @@ -0,0 +1,41 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +#[cfg(not(feature = "vsda"))] +pub fn create_challenge() -> String { + use rand::distributions::{Alphanumeric, DistString}; + Alphanumeric.sample_string(&mut rand::thread_rng(), 16) +} + +#[cfg(not(feature = "vsda"))] +pub fn sign_challenge(challenge: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hash = Sha256::new(); + hash.update(challenge.as_bytes()); + let result = hash.finalize(); + base64::encode_config(result, base64::URL_SAFE_NO_PAD) +} + +#[cfg(not(feature = "vsda"))] +pub fn verify_challenge(challenge: &str, response: &str) -> bool { + sign_challenge(challenge) == response +} + +#[cfg(feature = "vsda")] +pub fn create_challenge() -> String { + use rand::distributions::{Alphanumeric, DistString}; + let str = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + vsda::create_new_message(&str) +} + +#[cfg(feature = "vsda")] +pub fn sign_challenge(challenge: &str) -> String { + vsda::sign(challenge) +} + +#[cfg(feature = "vsda")] +pub fn verify_challenge(challenge: &str, response: &str) -> bool { + vsda::validate(challenge, response) +} diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index 67a0bcf64ac..bf85f1b28bb 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -5,16 +5,15 @@ use crate::async_pipe::get_socket_rw_stream; use crate::constants::{CONTROL_PORT, PRODUCT_NAME_LONG}; use crate::log; -use crate::msgpack_rpc::U32PrefixedCodec; -use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization}; +use crate::msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCodec, MsgPackSerializer}; +use crate::rpc::{MaybeSync, RpcBuilder, RpcCaller, RpcDispatcher}; use crate::self_update::SelfUpdate; use crate::state::LauncherPaths; -use crate::tunnels::protocol::HttpRequestParams; +use crate::tunnels::protocol::{HttpRequestParams, METHOD_CHALLENGE_ISSUE}; use crate::tunnels::socket_signal::CloseReason; use crate::update_service::{Platform, Release, TargetKind, UpdateService}; use crate::util::errors::{ - wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError, - NoAttachedServerError, + wrap, AnyError, CodeError, MismatchedLaunchModeError, NoAttachedServerError, }; use crate::util::http::{ DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp, @@ -22,7 +21,7 @@ use crate::util::http::{ use crate::util::io::SilentCopyProgress; use crate::util::is_integrated_cli; use crate::util::os::os_release; -use crate::util::sync::{new_barrier, Barrier}; +use crate::util::sync::{new_barrier, Barrier, BarrierOpener}; use futures::stream::FuturesUnordered; use futures::FutureExt; @@ -31,6 +30,7 @@ use opentelemetry::KeyValue; use std::collections::HashMap; use std::process::Stdio; use tokio::pin; +use tokio::process::{ChildStderr, ChildStdin}; use tokio_util::codec::Decoder; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; @@ -39,6 +39,7 @@ use std::time::Instant; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream}; use tokio::sync::{mpsc, Mutex}; +use super::challenge::{create_challenge, sign_challenge, verify_challenge}; use super::code_server::{ download_cli_into_cache, AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw, SocketCodeServer, @@ -47,11 +48,12 @@ use super::dev_tunnels::ActiveTunnel; use super::paths::prune_stopped_servers; use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ - AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, - ForwardParams, ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, - GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, - ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, - UpdateResult, VersionParams, + AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse, + ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult, + FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams, + HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult, + ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse, + METHOD_CHALLENGE_VERIFY, }; use super::server_bridge::ServerBridge; use super::server_multiplexer::ServerMultiplexer; @@ -68,6 +70,8 @@ struct HandlerContext { log: log::Logger, /// Whether the server update during the handler session. did_update: Arc, + /// Whether authentication is still required on the socket. + auth_state: Arc>, /// A loopback channel to talk to the socket server task. socket_tx: mpsc::Sender, /// Configured launcher paths. @@ -79,7 +83,7 @@ struct HandlerContext { // the cli arguments used to start the code server code_server_args: CodeServerArgs, /// port forwarding functionality - port_forwarding: PortForwarding, + port_forwarding: Option, /// install platform for the VS Code server platform: Platform, /// http client to make download/update requests @@ -88,6 +92,16 @@ struct HandlerContext { http_requests: HttpRequestsMap, } +/// Handler auth state. +enum AuthState { + /// Auth is required, we're waiting for the client to send its challenge. + WaitingForChallenge, + /// A challenge has been issued. Waiting for a verification. + ChallengeIssued(String), + /// Auth is no longer required. + Authenticated, +} + static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0); // Gets a next incrementing number that can be used in logs @@ -195,7 +209,14 @@ pub async fn serve( debug!(own_log, "Serving new connection"); let (writehalf, readhalf) = socket.into_split(); - let stats = process_socket(own_exit, readhalf, writehalf, own_log, own_tx, own_paths, own_code_server_args, own_forwarding, platform).with_context(cx.clone()).await; + let stats = process_socket(readhalf, writehalf, own_tx, Some(own_forwarding), ServeStreamParams { + log: own_log, + launcher_paths: own_paths, + code_server_args: own_code_server_args, + platform, + exit_barrier: own_exit, + requires_auth: false, + }).with_context(cx.clone()).await; cx.span().add_event( "socket.bandwidth", @@ -206,69 +227,91 @@ pub async fn serve( ], ); cx.span().end(); - }); + }); } } } } -struct SocketStats { +pub struct ServeStreamParams { + pub log: log::Logger, + pub launcher_paths: LauncherPaths, + pub code_server_args: CodeServerArgs, + pub platform: Platform, + pub requires_auth: bool, + pub exit_barrier: Barrier, +} + +pub async fn serve_stream( + readhalf: impl AsyncRead + Send + Unpin + 'static, + writehalf: impl AsyncWrite + Unpin, + params: ServeStreamParams, +) -> SocketStats { + // Currently the only server signal is respawn, that doesn't have much meaning + // when serving a stream, so make an ignored channel. + let (server_rx, server_tx) = mpsc::channel(1); + drop(server_tx); + + process_socket(readhalf, writehalf, server_rx, None, params).await +} + +pub struct SocketStats { rx: usize, tx: usize, } -#[derive(Copy, Clone)] -struct MsgPackSerializer {} - -impl Serialization for MsgPackSerializer { - fn serialize(&self, value: impl serde::Serialize) -> Vec { - rmp_serde::to_vec_named(&value).expect("expected to serialize") - } - - fn deserialize(&self, b: &[u8]) -> Result { - rmp_serde::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into()) - } -} - -#[allow(clippy::too_many_arguments)] // necessary here -async fn process_socket( - mut exit_barrier: Barrier<()>, - readhalf: impl AsyncRead + Send + Unpin + 'static, - mut writehalf: impl AsyncWrite + Unpin, +#[allow(clippy::too_many_arguments)] +fn make_socket_rpc( log: log::Logger, - server_tx: mpsc::Sender, + socket_tx: mpsc::Sender, + http_delegated: DelegatedSimpleHttp, launcher_paths: LauncherPaths, code_server_args: CodeServerArgs, - port_forwarding: PortForwarding, + port_forwarding: Option, + requires_auth: bool, platform: Platform, -) -> SocketStats { - let (socket_tx, mut socket_rx) = mpsc::channel(4); - let rx_counter = Arc::new(AtomicUsize::new(0)); +) -> RpcDispatcher { let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new())); let server_bridges = ServerMultiplexer::new(); - let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone()); let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext { did_update: Arc::new(AtomicBool::new(false)), - socket_tx: socket_tx.clone(), + auth_state: Arc::new(std::sync::Mutex::new(match requires_auth { + true => AuthState::WaitingForChallenge, + false => AuthState::Authenticated, + })), + socket_tx, log: log.clone(), launcher_paths, code_server_args, code_server: Arc::new(Mutex::new(None)), - server_bridges: server_bridges.clone(), + server_bridges, port_forwarding, platform, http: Arc::new(FallbackSimpleHttp::new( ReqwestSimpleHttp::new(), http_delegated, )), - http_requests: http_requests.clone(), + http_requests, }); rpc.register_sync("ping", |_: EmptyObject, _| Ok(EmptyObject {})); rpc.register_sync("gethostname", |_: EmptyObject, _| handle_get_hostname()); - rpc.register_sync("fs_stat", |p: FsStatRequest, _| handle_stat(p.path)); - rpc.register_sync("get_env", |_: EmptyObject, _| handle_get_env()); + rpc.register_sync("fs_stat", |p: FsStatRequest, c| { + ensure_auth(&c.auth_state)?; + handle_stat(p.path) + }); + rpc.register_sync("get_env", |_: EmptyObject, c| { + ensure_auth(&c.auth_state)?; + handle_get_env() + }); + rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| { + handle_challenge_issue(&c.auth_state) + }); + rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| { + handle_challenge_verify(p.response, &c.auth_state) + }); rpc.register_async("serve", move |params: ServeParams, c| async move { + ensure_auth(&c.auth_state)?; handle_serve(c, params).await }); rpc.register_async("update", |p: UpdateParams, c| async move { @@ -286,15 +329,19 @@ async fn process_socket( handle_call_server_http(code_server, p).await }); rpc.register_async("forward", |p: ForwardParams, c| async move { + ensure_auth(&c.auth_state)?; handle_forward(&c.log, &c.port_forwarding, p).await }); rpc.register_async("unforward", |p: UnforwardParams, c| async move { + ensure_auth(&c.auth_state)?; handle_unforward(&c.log, &c.port_forwarding, p).await }); rpc.register_async("acquire_cli", |p: AcquireCliParams, c| async move { + ensure_auth(&c.auth_state)?; handle_acquire_cli(&c.launcher_paths, &c.http, &c.log, p).await }); rpc.register_duplex("spawn", 3, |mut streams, p: SpawnParams, c| async move { + ensure_auth(&c.auth_state)?; handle_spawn( &c.log, p, @@ -304,13 +351,28 @@ async fn process_socket( ) .await }); + rpc.register_duplex( + "spawn_cli", + 3, + |mut streams, p: SpawnParams, c| async move { + ensure_auth(&c.auth_state)?; + handle_spawn_cli( + &c.log, + p, + streams.remove(0), + streams.remove(0), + streams.remove(0), + ) + .await + }, + ); rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| { if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) { req.initial_response(p.status_code, p.headers); } Ok(EmptyObject {}) }); - rpc.register_sync("unforward", move |p: HttpBodyParams, c| { + rpc.register_sync("httpbody", move |p: HttpBodyParams, c| { let mut reqs = c.http_requests.lock().unwrap(); if let Some(req) = reqs.get(&p.req_id) { if !p.segment.is_empty() { @@ -322,15 +384,64 @@ async fn process_socket( } Ok(EmptyObject {}) }); + rpc.register_sync( + "version", + |_: EmptyObject, _| Ok(VersionResponse::default()), + ); + + rpc.build(log) +} + +fn ensure_auth(is_authed: &Arc>) -> Result<(), AnyError> { + if let AuthState::Authenticated = &*is_authed.lock().unwrap() { + Ok(()) + } else { + Err(CodeError::ServerAuthRequired.into()) + } +} + +#[allow(clippy::too_many_arguments)] // necessary here +async fn process_socket( + readhalf: impl AsyncRead + Send + Unpin + 'static, + mut writehalf: impl AsyncWrite + Unpin, + server_tx: mpsc::Sender, + port_forwarding: Option, + params: ServeStreamParams, +) -> SocketStats { + let ServeStreamParams { + mut exit_barrier, + log, + launcher_paths, + code_server_args, + platform, + requires_auth, + } = params; + + let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone()); + let (socket_tx, mut socket_rx) = mpsc::channel(4); + let rx_counter = Arc::new(AtomicUsize::new(0)); + let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new())); + + let rpc = make_socket_rpc( + log.clone(), + socket_tx.clone(), + http_delegated, + launcher_paths, + code_server_args, + port_forwarding, + requires_auth, + platform, + ); { let log = log.clone(); let rx_counter = rx_counter.clone(); let socket_tx = socket_tx.clone(); let exit_barrier = exit_barrier.clone(); - let rpc = rpc.build(log.clone()); tokio::spawn(async move { - send_version(&socket_tx).await; + if !requires_auth { + send_version(&socket_tx).await; + } if let Err(e) = handle_socket_read(&log, readhalf, exit_barrier, &socket_tx, rx_counter, &rpc).await @@ -350,6 +461,10 @@ async fn process_socket( } ctx.dispose().await; + + let _ = socket_tx + .send(SocketSignal::CloseWith(CloseReason("eof".to_string()))) + .await; }); } @@ -408,7 +523,7 @@ async fn process_socket( async fn send_version(tx: &mpsc::Sender) { tx.send(SocketSignal::from_message(&ToClientRequest { id: None, - params: ClientRequestMethod::version(VersionParams::default()), + params: ClientRequestMethod::version(VersionResponse::default()), })) .await .ok(); @@ -416,13 +531,13 @@ async fn send_version(tx: &mpsc::Sender) { async fn handle_socket_read( _log: &log::Logger, readhalf: impl AsyncRead + Unpin, - mut closer: Barrier<()>, + mut closer: Barrier, socket_tx: &mpsc::Sender, rx_counter: Arc, rpc: &RpcDispatcher, ) -> Result<(), std::io::Error> { let mut readhalf = BufReader::new(readhalf); - let mut decoder = U32PrefixedCodec {}; + let mut decoder = MsgPackCodec::new(); let mut decoder_buf = bytes::BytesMut::new(); loop { @@ -431,10 +546,14 @@ async fn handle_socket_read( _ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")), }?; + if read_len == 0 { + return Ok(()); + } + rx_counter.fetch_add(read_len, Ordering::Relaxed); while let Some(frame) = decoder.decode(&mut decoder_buf)? { - match rpc.dispatch(&frame) { + match rpc.dispatch_with_partial(&frame.vec, frame.obj) { MaybeSync::Sync(Some(v)) => { if socket_tx.send(SocketSignal::Send(v)).await.is_err() { return Ok(()); @@ -704,11 +823,44 @@ fn handle_get_env() -> Result { }) } +fn handle_challenge_issue( + auth_state: &Arc>, +) -> Result { + let challenge = create_challenge(); + + let mut auth_state = auth_state.lock().unwrap(); + *auth_state = AuthState::ChallengeIssued(challenge.clone()); + + Ok(ChallengeIssueResponse { challenge }) +} + +fn handle_challenge_verify( + response: String, + auth_state: &Arc>, +) -> Result { + let mut auth_state = auth_state.lock().unwrap(); + + match &*auth_state { + AuthState::Authenticated => Ok(EmptyObject {}), + AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()), + AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) { + false => Err(CodeError::AuthChallengeNotIssued.into()), + true => { + *auth_state = AuthState::Authenticated; + Ok(EmptyObject {}) + } + }, + } +} + async fn handle_forward( log: &log::Logger, - port_forwarding: &PortForwarding, + port_forwarding: &Option, params: ForwardParams, ) -> Result { + let port_forwarding = port_forwarding + .as_ref() + .ok_or(CodeError::PortForwardingNotAvailable)?; info!(log, "Forwarding port {}", params.port); let uri = port_forwarding.forward(params.port).await?; Ok(ForwardResult { uri }) @@ -716,9 +868,12 @@ async fn handle_forward( async fn handle_unforward( log: &log::Logger, - port_forwarding: &PortForwarding, + port_forwarding: &Option, params: UnforwardParams, ) -> Result { + let port_forwarding = port_forwarding + .as_ref() + .ok_or(CodeError::PortForwardingNotAvailable)?; info!(log, "Unforwarding port {}", params.port); port_forwarding.unforward(params.port).await?; Ok(EmptyObject {}) @@ -818,17 +973,17 @@ async fn handle_spawn( stderr: Option, ) -> Result where - Stdin: AsyncRead + Unpin + Send, - StdoutAndErr: AsyncWrite + Unpin + Send, + Stdin: AsyncRead + Unpin + Send + 'static, + StdoutAndErr: AsyncWrite + Unpin + Send + 'static, { debug!( log, "requested to spawn {} with args {:?}", params.command, params.args ); - macro_rules! pipe_if_some { + macro_rules! pipe_if { ($e: expr) => { - if $e.is_some() { + if $e { Stdio::piped() } else { Stdio::null() @@ -839,9 +994,9 @@ where let mut p = tokio::process::Command::new(¶ms.command); p.args(¶ms.args); p.envs(¶ms.env); - p.stdin(pipe_if_some!(stdin)); - p.stdout(pipe_if_some!(stdout)); - p.stderr(pipe_if_some!(stderr)); + p.stdin(pipe_if!(stdin.is_some())); + p.stdout(pipe_if!(stdin.is_some())); + p.stderr(pipe_if!(stderr.is_some())); if let Some(cwd) = ¶ms.cwd { p.current_dir(cwd); } @@ -859,7 +1014,72 @@ where futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed()); } - let closed = p.wait(); + wait_for_process_exit(log, ¶ms.command, p, futs).await +} + +async fn handle_spawn_cli( + log: &log::Logger, + params: SpawnParams, + mut protocol_in: DuplexStream, + mut protocol_out: DuplexStream, + mut log_out: DuplexStream, +) -> Result { + debug!( + log, + "requested to spawn cli {} with args {:?}", params.command, params.args + ); + + let mut p = tokio::process::Command::new(¶ms.command); + p.args(¶ms.args); + + // CLI args to spawn a server; contracted with clients that they should _not_ provide these. + p.arg("--verbose"); + p.arg("tunnel"); + p.arg("stdio"); + + p.envs(¶ms.env); + p.stdin(Stdio::piped()); + p.stdout(Stdio::piped()); + p.stderr(Stdio::piped()); + if let Some(cwd) = ¶ms.cwd { + p.current_dir(cwd); + } + + let mut p = p.spawn().map_err(CodeError::ProcessSpawnFailed)?; + + let mut stdin = p.stdin.take().unwrap(); + let mut stdout = p.stdout.take().unwrap(); + let mut stderr = p.stderr.take().unwrap(); + + // Start handling logs while doing the handshake in case there's some kind of error + let log_pump = tokio::spawn(async move { tokio::io::copy(&mut stdout, &mut log_out).await }); + + // note: intentionally do not wrap stdin in a bufreader, since we don't + // want to read anything other than our handshake messages. + if let Err(e) = spawn_do_child_authentication(log, &mut stdin, &mut stderr).await { + warning!(log, "failed to authenticate with child process {}", e); + let _ = p.kill().await; + return Err(e.into()); + } + + debug!(log, "cli authenticated, attaching stdio"); + let futs = FuturesUnordered::new(); + futs.push(async move { tokio::io::copy(&mut protocol_in, &mut stdin).await }.boxed()); + futs.push(async move { tokio::io::copy(&mut stderr, &mut protocol_out).await }.boxed()); + futs.push(async move { log_pump.await.unwrap() }.boxed()); + + wait_for_process_exit(log, ¶ms.command, p, futs).await +} + +type TokioCopyFuture = dyn futures::Future> + Send; + +async fn wait_for_process_exit( + log: &log::Logger, + command: &str, + mut process: tokio::process::Child, + futs: FuturesUnordered>>, +) -> Result { + let closed = process.wait(); pin!(closed); let r = tokio::select! { @@ -880,8 +1100,69 @@ where debug!( log, - "spawned command {} exited with code {}", params.command, r.exit_code + "spawned cli {} exited with code {}", command, r.exit_code ); Ok(r) } + +async fn spawn_do_child_authentication( + log: &log::Logger, + stdin: &mut ChildStdin, + stdout: &mut ChildStderr, +) -> Result<(), CodeError> { + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); + let (shutdown_rx, shutdown) = new_barrier(); + let mut rpc = new_msgpack_rpc(); + let caller = rpc.get_caller(msg_tx); + + let challenge_response = do_challenge_response_flow(caller, shutdown); + let rpc = start_msgpack_rpc( + rpc.methods(()).build(log.prefixed("client-auth")), + stdout, + stdin, + msg_rx, + shutdown_rx, + ); + pin!(rpc); + + tokio::select! { + r = &mut rpc => { + match r { + // means shutdown happened cleanly already, we're good + Ok(_) => Ok(()), + Err(e) => Err(CodeError::ProcessSpawnHandshakeFailed(e)) + } + }, + r = challenge_response => { + r?; + rpc.await.map(|_| ()).map_err(CodeError::ProcessSpawnFailed) + } + } +} + +async fn do_challenge_response_flow( + caller: RpcCaller, + shutdown: BarrierOpener<()>, +) -> Result<(), CodeError> { + let challenge: ChallengeIssueResponse = caller + .call(METHOD_CHALLENGE_ISSUE, EmptyObject {}) + .await + .unwrap() + .map_err(CodeError::TunnelRpcCallFailed)?; + + let _: EmptyObject = caller + .call( + METHOD_CHALLENGE_VERIFY, + ChallengeVerifyParams { + response: sign_challenge(&challenge.challenge), + }, + ) + .await + .unwrap() + .map_err(CodeError::TunnelRpcCallFailed)?; + + shutdown.open(()); + + Ok(()) +} diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index 17282381c55..eb20afe0ce5 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -18,7 +18,7 @@ pub enum ClientRequestMethod<'a> { servermsg(RefServerMessageParams<'a>), serverlog(ServerLog<'a>), makehttpreq(HttpRequestParams<'a>), - version(VersionParams), + version(VersionResponse), } #[derive(Deserialize, Debug)] @@ -58,14 +58,6 @@ pub struct ForwardResult { pub uri: String, } -/// The `install_local` method in the wsl control server -#[derive(Deserialize, Debug)] -pub struct InstallFromLocalFolderParams { - pub archive_path: String, - #[serde(flatten)] - pub inner: ServeParams, -} - #[derive(Deserialize, Debug)] pub struct ServeParams { pub socket_id: u16, @@ -165,12 +157,12 @@ pub struct CallServerHttpResult { } #[derive(Serialize, Debug)] -pub struct VersionParams { +pub struct VersionResponse { pub version: &'static str, pub protocol_version: u32, } -impl Default for VersionParams { +impl Default for VersionResponse { fn default() -> Self { Self { version: VSCODE_CLI_VERSION.unwrap_or("dev"), @@ -204,6 +196,19 @@ pub struct SpawnResult { pub exit_code: i32, } +pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue"; +pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify"; + +#[derive(Serialize, Deserialize)] +pub struct ChallengeIssueResponse { + pub challenge: String, +} + +#[derive(Deserialize, Serialize)] +pub struct ChallengeVerifyParams { + pub response: String, +} + pub mod singleton { use crate::log; use serde::{Deserialize, Serialize}; diff --git a/cli/src/tunnels/singleton_client.rs b/cli/src/tunnels/singleton_client.rs index b67a53306d1..ef9fdf85cc0 100644 --- a/cli/src/tunnels/singleton_client.rs +++ b/cli/src/tunnels/singleton_client.rs @@ -74,7 +74,6 @@ pub async fn start_singleton_client(args: SingletonClientArgs) -> bool { let mut input = String::new(); loop { input.truncate(0); - println!("reading line"); match std::io::stdin().read_line(&mut input) { Err(_) | Ok(0) => return, // EOF or not a tty _ => {} diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs index a3d3b08a5d4..2a2df6607ea 100644 --- a/cli/src/tunnels/socket_signal.rs +++ b/cli/src/tunnels/socket_signal.rs @@ -38,6 +38,7 @@ impl SocketSignal { } /// todo@connor4312: cleanup once everything is moved to rpc standard interfaces +#[allow(dead_code)] pub enum ServerMessageDestination { Channel(mpsc::Sender), Rpc(MsgPackCaller), diff --git a/cli/src/tunnels/wsl_server.rs b/cli/src/tunnels/wsl_server.rs deleted file mode 100644 index 3eafae92f3a..00000000000 --- a/cli/src/tunnels/wsl_server.rs +++ /dev/null @@ -1,173 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -use std::sync::Arc; - -use tokio::sync::mpsc; - -use crate::{ - log, - msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCaller}, - state::LauncherPaths, - tunnels::code_server::ServerBuilder, - update_service::{Platform, Release, TargetKind}, - util::{ - errors::{ - wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError, - }, - http::ReqwestSimpleHttp, - sync::Barrier, - }, -}; - -use super::{ - code_server::{AnyCodeServer, CodeServerArgs, ResolvedServerParams}, - protocol::{EmptyObject, InstallFromLocalFolderParams, ServerMessageParams, VersionParams}, - server_bridge::ServerBridge, - server_multiplexer::ServerMultiplexer, - shutdown_signal::ShutdownSignal, - socket_signal::{ClientMessageDecoder, ServerMessageDestination, ServerMessageSink}, -}; - -struct HandlerContext { - log: log::Logger, - code_server_args: CodeServerArgs, - launcher_paths: LauncherPaths, - platform: Platform, - http: ReqwestSimpleHttp, - caller: MsgPackCaller, - multiplexer: ServerMultiplexer, -} - -#[derive(Clone)] -struct RpcLogSink(MsgPackCaller); - -impl RpcLogSink { - fn write_json(&self, level: String, message: &str) { - self.0.notify( - "log", - serde_json::json!({ - "level": level, - "message": message, - }), - ); - } -} - -impl log::LogSink for RpcLogSink { - fn write_log(&self, level: log::Level, _prefix: &str, message: &str) { - self.write_json(level.to_string(), message); - } - - fn write_result(&self, message: &str) { - self.write_json("result".to_string(), message); - } -} - -pub async fn serve_wsl( - log: log::Logger, - launcher_paths: LauncherPaths, - code_server_args: CodeServerArgs, - platform: Platform, - http: reqwest::Client, - shutdown_rx: Barrier, -) -> Result { - let (caller_tx, caller_rx) = mpsc::unbounded_channel(); - let mut rpc = new_msgpack_rpc(); - let caller = rpc.get_caller(caller_tx); - - // notify the incoming client about the server version - caller.notify("version", VersionParams::default()); - - let log = log.with_sink(RpcLogSink(caller.clone())); - let mut rpc = rpc.methods(HandlerContext { - log: log.clone(), - caller, - code_server_args, - launcher_paths, - platform, - multiplexer: ServerMultiplexer::new(), - http: ReqwestSimpleHttp::with_client(http), - }); - - rpc.register_async( - "serve", - move |m: InstallFromLocalFolderParams, c| async move { handle_serve(&c, m).await }, - ); - rpc.register_sync("servermsg", move |m: ServerMessageParams, c| { - if c.multiplexer.write_message(&c.log, m.i, m.body) { - Ok(EmptyObject {}) - } else { - Err(NoAttachedServerError().into()) - } - }); - - start_msgpack_rpc( - rpc.build(log), - tokio::io::stdin(), - tokio::io::stderr(), - caller_rx, - shutdown_rx, - ) - .await - .map_err(|e| wrap(e, "error handling server stdio"))?; - - Ok(0) -} - -async fn handle_serve( - c: &HandlerContext, - params: InstallFromLocalFolderParams, -) -> Result { - // fill params.extensions into code_server_args.install_extensions - let mut csa = c.code_server_args.clone(); - csa.connection_token = params.inner.connection_token.or(csa.connection_token); - csa.install_extensions - .extend(params.inner.extensions.into_iter()); - - let resolved = ResolvedServerParams { - code_server_args: csa, - release: Release { - name: String::new(), - commit: params - .inner - .commit_id - .ok_or_else(|| InvalidRpcDataError("commit_id is required".to_string()))?, - platform: c.platform, - target: TargetKind::Server, - quality: params.inner.quality, - }, - }; - - let sb = ServerBuilder::new( - &c.log, - &resolved, - &c.launcher_paths, - Arc::new(c.http.clone()), - ); - let code_server = match sb.get_running().await? { - Some(AnyCodeServer::Socket(s)) => s, - Some(_) => return Err(MismatchedLaunchModeError().into()), - None => { - sb.setup().await?; - sb.listen_on_default_socket().await? - } - }; - - let bridge = ServerBridge::new( - &code_server.socket, - ServerMessageSink::new_plain( - c.multiplexer.clone(), - params.inner.socket_id, - ServerMessageDestination::Rpc(c.caller.clone()), - ), - ClientMessageDecoder::new_plain(), - ) - .await?; - - c.multiplexer.register(params.inner.socket_id, bridge); - trace!(c.log, "Attached to server"); - Ok(EmptyObject {}) -} diff --git a/cli/src/util/errors.rs b/cli/src/util/errors.rs index 9ab421d3301..f1c4cbf5c22 100644 --- a/cli/src/util/errors.rs +++ b/cli/src/util/errors.rs @@ -479,9 +479,18 @@ pub enum CodeError { PrerequisitesFailed { name: &'static str, bullets: String }, #[error("failed to spawn process: {0:?}")] ProcessSpawnFailed(std::io::Error), - + #[error("failed to handshake spawned process: {0:?}")] + ProcessSpawnHandshakeFailed(std::io::Error), #[error("download appears corrupted, please retry ({0})")] CorruptDownload(&'static str), + #[error("port forwarding is not available in this context")] + PortForwardingNotAvailable, + #[error("'auth' call required")] + ServerAuthRequired, + #[error("challenge not yet issued")] + AuthChallengeNotIssued, + #[error("unauthorized client refused")] + AuthMismatch, } makeAnyError!( diff --git a/package.json b/package.json index 5ad0e0ef6c4..0b2e397d698 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "code-oss-dev", "version": "1.79.0", - "distro": "10ec4d08d4a06a1c18addcded03e90c8a0e6ecad", + "distro": "edc0b65674651d40e7bd668fbc675ddb30cec375", "author": { "name": "Microsoft Corporation" },