mirror of
https://github.com/Microsoft/vscode
synced 2024-10-05 19:02:54 +00:00
cli: add stdio control server
* signing: implement signing service on the web * wip * cli: implement stdio service This is used to implement the exec server for WSL. Guarded behind a signed handshake. * update distro * rm debug * address pr comments
This commit is contained in:
parent
c125687c4d
commit
679bb967c3
55
cli/Cargo.lock
generated
55
cli/Cargo.lock
generated
|
@ -373,30 +373,6 @@ dependencies = [
|
|||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f916dfc5d356b0ed9dae65f1db9fc9770aa2851d2662b988ccf4fe3516e86348"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"memoffset",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.12"
|
||||
|
@ -529,12 +505,6 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797"
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
version = "0.3.6"
|
||||
|
@ -1729,30 +1699,6 @@ dependencies = [
|
|||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd99e5772ead8baa5215278c9b15bf92087709e9c1b2d1f97cdb5a183c933a7d"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"crossbeam-deque",
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "258bcdb5ac6dad48491bb2992db6b7cf74878b0384908af124823d118c99683f"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.16"
|
||||
|
@ -2183,7 +2129,6 @@ dependencies = [
|
|||
"libc",
|
||||
"ntapi",
|
||||
"once_cell",
|
||||
"rayon",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ flate2 = { version = "1.0.22" }
|
|||
zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] }
|
||||
regex = { version = "1.5.5" }
|
||||
lazy_static = { version = "1.4.0" }
|
||||
sysinfo = { version = "0.27.7" }
|
||||
sysinfo = { version = "0.27.7", default-features = false }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = { version = "1.0" }
|
||||
rmp-serde = "1.0"
|
||||
|
|
|
@ -8,7 +8,7 @@ use std::process::Command;
|
|||
|
||||
use clap::Parser;
|
||||
use cli::{
|
||||
commands::{args, internal_wsl, tunnels, update, version, CommandContext},
|
||||
commands::{args, tunnels, update, version, CommandContext},
|
||||
constants::get_default_user_agent,
|
||||
desktop, log,
|
||||
state::LauncherPaths,
|
||||
|
@ -65,9 +65,6 @@ async fn main() -> Result<(), std::convert::Infallible> {
|
|||
..
|
||||
}) => match cmd {
|
||||
args::StandaloneCommands::Update(args) => update::update(context!(), args).await,
|
||||
args::StandaloneCommands::Wsl(args) => match args.command {
|
||||
args::WslCommands::Serve => internal_wsl::serve(context!()).await,
|
||||
},
|
||||
},
|
||||
args::AnyCli::Standalone(args::StandaloneCli { core: c, .. })
|
||||
| args::AnyCli::Integrated(args::IntegratedCli { core: c, .. }) => match c.subcommand {
|
||||
|
@ -98,6 +95,8 @@ async fn main() -> Result<(), std::convert::Infallible> {
|
|||
args::VersionSubcommand::Show => version::show(context!()).await,
|
||||
},
|
||||
|
||||
Some(args::Commands::CommandShell) => tunnels::command_shell(context!()).await,
|
||||
|
||||
Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand {
|
||||
Some(args::TunnelSubcommand::Prune) => tunnels::prune(context!()).await,
|
||||
Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context!()).await,
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
mod context;
|
||||
|
||||
pub mod args;
|
||||
pub mod internal_wsl;
|
||||
pub mod tunnels;
|
||||
pub mod update;
|
||||
pub mod version;
|
||||
|
|
|
@ -146,22 +146,6 @@ impl<'a> From<&'a CliCore> for CodeServerArgs {
|
|||
pub enum StandaloneCommands {
|
||||
/// Updates the CLI.
|
||||
Update(StandaloneUpdateArgs),
|
||||
|
||||
/// Internal commands for WSL serving.
|
||||
#[clap(hide = true)]
|
||||
Wsl(WslArgs),
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
pub struct WslArgs {
|
||||
#[clap(subcommand)]
|
||||
pub command: WslCommands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
pub enum WslCommands {
|
||||
/// Runs the WSL server on stdin/out
|
||||
Serve,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
|
@ -187,6 +171,10 @@ pub enum Commands {
|
|||
|
||||
/// Changes the version of the editor you're using.
|
||||
Version(VersionArgs),
|
||||
|
||||
/// Runs the control server on process stdin/stdout
|
||||
#[clap(hide = true)]
|
||||
CommandShell,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use crate::{
|
||||
tunnels::{serve_wsl, shutdown_signal::ShutdownRequest},
|
||||
util::{errors::AnyError, prereqs::PreReqChecker},
|
||||
};
|
||||
|
||||
use super::CommandContext;
|
||||
|
||||
pub async fn serve(ctx: CommandContext) -> Result<i32, AnyError> {
|
||||
let signal = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]);
|
||||
let platform = spanf!(
|
||||
ctx.log,
|
||||
ctx.log.span("prereq"),
|
||||
PreReqChecker::new().verify()
|
||||
)?;
|
||||
|
||||
serve_wsl(
|
||||
ctx.log,
|
||||
ctx.paths,
|
||||
(&ctx.args).into(),
|
||||
platform,
|
||||
ctx.http,
|
||||
signal,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(0)
|
||||
}
|
|
@ -25,13 +25,13 @@ use crate::{
|
|||
code_server::CodeServerArgs,
|
||||
create_service_manager, dev_tunnels, legal,
|
||||
paths::get_all_servers,
|
||||
protocol,
|
||||
protocol, serve_stream,
|
||||
shutdown_signal::ShutdownRequest,
|
||||
singleton_client::do_single_rpc_call,
|
||||
singleton_server::{
|
||||
make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
|
||||
},
|
||||
Next, ServiceContainer, ServiceManager,
|
||||
Next, ServeStreamParams, ServiceContainer, ServiceManager,
|
||||
},
|
||||
util::{
|
||||
app_lock::AppMutex,
|
||||
|
@ -107,6 +107,25 @@ impl ServiceContainer for TunnelServiceContainer {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> {
|
||||
let platform = PreReqChecker::new().verify().await?;
|
||||
serve_stream(
|
||||
tokio::io::stdin(),
|
||||
tokio::io::stderr(),
|
||||
ServeStreamParams {
|
||||
log: ctx.log,
|
||||
launcher_paths: ctx.paths,
|
||||
platform,
|
||||
requires_auth: true,
|
||||
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
|
||||
code_server_args: (&ctx.args).into(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
pub async fn service(
|
||||
ctx: CommandContext,
|
||||
service_args: TunnelServiceSubCommands,
|
||||
|
|
|
@ -18,7 +18,8 @@ pub const CONTROL_PORT: u16 = 31545;
|
|||
/// 2 - Addition of `serve.compressed` property to control whether servermsg's
|
||||
/// are compressed bidirectionally.
|
||||
/// 3 - The server's connection token is set to a SHA256 hash of the tunnel ID
|
||||
pub const PROTOCOL_VERSION: u32 = 3;
|
||||
/// 4 - The server's msgpack messages are no longer length-prefixed
|
||||
pub const PROTOCOL_VERSION: u32 = 4;
|
||||
|
||||
/// Prefix for the tunnel tag that includes the version.
|
||||
pub const PROTOCOL_VERSION_TAG_PREFIX: &str = "protocolv";
|
||||
|
|
|
@ -4,8 +4,9 @@
|
|||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use bytes::Buf;
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader},
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
pin,
|
||||
sync::mpsc,
|
||||
};
|
||||
|
@ -18,7 +19,7 @@ use crate::{
|
|||
sync::{Barrier, Receivable},
|
||||
},
|
||||
};
|
||||
use std::io;
|
||||
use std::io::{self, Cursor, ErrorKind};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MsgPackSerializer {}
|
||||
|
@ -35,21 +36,28 @@ impl Serialization for MsgPackSerializer {
|
|||
|
||||
pub type MsgPackCaller = rpc::RpcCaller<MsgPackSerializer>;
|
||||
|
||||
/// Creates a new RPC Builder that serializes to JSON.
|
||||
/// Creates a new RPC Builder that serializes to msgpack.
|
||||
pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
|
||||
rpc::RpcBuilder::new(MsgPackSerializer {})
|
||||
}
|
||||
|
||||
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
|
||||
dispatcher: rpc::RpcDispatcher<MsgPackSerializer, C>,
|
||||
read: impl AsyncRead + Unpin,
|
||||
mut write: impl AsyncWrite + Unpin,
|
||||
/// Starting processing msgpack rpc over the given i/o. It's recommended that
|
||||
/// the reader be passed in as a BufReader for efficiency.
|
||||
pub async fn start_msgpack_rpc<
|
||||
C: Send + Sync + 'static,
|
||||
X: Clone,
|
||||
S: Send + Sync + Serialization,
|
||||
Read: AsyncRead + Unpin,
|
||||
Write: AsyncWrite + Unpin,
|
||||
>(
|
||||
dispatcher: rpc::RpcDispatcher<S, C>,
|
||||
mut read: Read,
|
||||
mut write: Write,
|
||||
mut msg_rx: impl Receivable<Vec<u8>>,
|
||||
mut shutdown_rx: Barrier<S>,
|
||||
) -> io::Result<Option<S>> {
|
||||
mut shutdown_rx: Barrier<X>,
|
||||
) -> io::Result<(Option<X>, Read, Write)> {
|
||||
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
|
||||
let mut read = BufReader::new(read);
|
||||
let mut decoder = U32PrefixedCodec {};
|
||||
let mut decoder = MsgPackCodec::new();
|
||||
let mut decoder_buf = bytes::BytesMut::new();
|
||||
|
||||
let shutdown_fut = shutdown_rx.wait();
|
||||
|
@ -61,7 +69,7 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
|
|||
r?;
|
||||
|
||||
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
|
||||
match dispatcher.dispatch(&frame) {
|
||||
match dispatcher.dispatch_with_partial(&frame.vec, frame.obj) {
|
||||
MaybeSync::Sync(Some(v)) => {
|
||||
let _ = write_tx.send(v).await;
|
||||
},
|
||||
|
@ -94,39 +102,94 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
|
|||
Some(m) = msg_rx.recv_msg() => {
|
||||
write.write_all(&m).await?;
|
||||
},
|
||||
r = &mut shutdown_fut => return Ok(r.ok()),
|
||||
r = &mut shutdown_fut => return Ok((r.ok(), read, write)),
|
||||
}
|
||||
|
||||
write.flush().await?;
|
||||
}
|
||||
}
|
||||
|
||||
/// Reader that reads length-prefixed msgpack messages in a cancellation-safe
|
||||
/// way using Tokio's codecs.
|
||||
pub struct U32PrefixedCodec {}
|
||||
/// Reader that reads msgpack object messages in a cancellation-safe way using Tokio's codecs.
|
||||
///
|
||||
/// rmp_serde does not support async reads, and does not plan to. But we know every
|
||||
/// type in protocol is some kind of object, so by asking to deserialize the
|
||||
/// requested object from a reader (repeatedly, if incomplete) we can
|
||||
/// accomplish streaming.
|
||||
pub struct MsgPackCodec<T> {
|
||||
_marker: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
const U32_SIZE: usize = 4;
|
||||
impl<T> MsgPackCodec<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_marker: std::marker::PhantomData::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Decoder for U32PrefixedCodec {
|
||||
type Item = Vec<u8>;
|
||||
pub struct MsgPackDecoded<T> {
|
||||
pub obj: T,
|
||||
pub vec: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<T: DeserializeOwned> tokio_util::codec::Decoder for MsgPackCodec<T> {
|
||||
type Item = MsgPackDecoded<T>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if src.len() < 4 {
|
||||
src.reserve(U32_SIZE - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes_ref = src.as_ref();
|
||||
let mut cursor = Cursor::new(bytes_ref);
|
||||
|
||||
let mut be_bytes = [0; U32_SIZE];
|
||||
be_bytes.copy_from_slice(&src[..U32_SIZE]);
|
||||
let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize);
|
||||
if src.len() < required_len {
|
||||
src.reserve(required_len - src.len());
|
||||
return Ok(None);
|
||||
match rmp_serde::decode::from_read::<_, T>(&mut cursor) {
|
||||
Err(
|
||||
rmp_serde::decode::Error::InvalidDataRead(e)
|
||||
| rmp_serde::decode::Error::InvalidMarkerRead(e),
|
||||
) if e.kind() == ErrorKind::UnexpectedEof => {
|
||||
src.reserve(1024);
|
||||
Ok(None)
|
||||
}
|
||||
Err(e) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
e.to_string(),
|
||||
)),
|
||||
Ok(obj) => {
|
||||
let len = cursor.position() as usize;
|
||||
let vec = src[..len].to_vec();
|
||||
src.advance(len);
|
||||
Ok(Some(MsgPackDecoded { obj, vec }))
|
||||
}
|
||||
}
|
||||
|
||||
let msg = src[U32_SIZE..required_len].to_vec();
|
||||
src.advance(required_len);
|
||||
Ok(Some(msg))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
|
||||
pub struct Msg {
|
||||
pub x: i32,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protocol() {
|
||||
let mut c = MsgPackCodec::<Msg>::new();
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
|
||||
assert!(c.decode(&mut buf).unwrap().is_none());
|
||||
|
||||
buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 1 }).unwrap().as_slice());
|
||||
buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 2 }).unwrap().as_slice());
|
||||
|
||||
assert_eq!(
|
||||
c.decode(&mut buf).unwrap().expect("expected msg1").obj,
|
||||
Msg { x: 1 }
|
||||
);
|
||||
assert_eq!(
|
||||
c.decode(&mut buf).unwrap().expect("expected msg1").obj,
|
||||
Msg { x: 2 }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -104,6 +104,10 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
|||
R: Serialize,
|
||||
F: Fn(P, &C) -> Result<R, AnyError> + Send + Sync + 'static,
|
||||
{
|
||||
if self.methods.contains_key(method_name) {
|
||||
panic!("Method already registered: {}", method_name);
|
||||
}
|
||||
|
||||
let serial = self.serializer.clone();
|
||||
let context = self.context.clone();
|
||||
self.methods.insert(
|
||||
|
@ -276,7 +280,9 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
|||
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
|
||||
let s1 = s1.clone();
|
||||
async move {
|
||||
s1.lock().await.remove(&m.stream);
|
||||
if let Some(mut s) = s1.lock().await.remove(&m.stream) {
|
||||
let _ = s.shutdown().await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
@ -410,13 +416,17 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
|||
/// The future or return result will be optional bytes that should be sent
|
||||
/// back to the socket.
|
||||
pub fn dispatch(&self, body: &[u8]) -> MaybeSync {
|
||||
let partial = match self.serializer.deserialize::<PartialIncoming>(body) {
|
||||
Ok(b) => b,
|
||||
match self.serializer.deserialize::<PartialIncoming>(body) {
|
||||
Ok(partial) => self.dispatch_with_partial(body, partial),
|
||||
Err(_err) => {
|
||||
warning!(self.log, "Failed to deserialize request, hex: {:X?}", body);
|
||||
return MaybeSync::Sync(None);
|
||||
MaybeSync::Sync(None)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Like dispatch, but allows passing an existing PartialIncoming.
|
||||
pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync {
|
||||
let id = partial.id;
|
||||
|
||||
if let Some(method_name) = partial.method {
|
||||
|
@ -536,8 +546,8 @@ trait AssertIsSync: Sync {}
|
|||
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
|
||||
|
||||
/// Approximate shape that is used to determine what kind of data is incoming.
|
||||
#[derive(Deserialize)]
|
||||
struct PartialIncoming {
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PartialIncoming {
|
||||
pub id: Option<u32>,
|
||||
pub method: Option<String>,
|
||||
pub error: Option<ResponseError>,
|
||||
|
|
|
@ -7,11 +7,12 @@ pub mod code_server;
|
|||
pub mod dev_tunnels;
|
||||
pub mod legal;
|
||||
pub mod paths;
|
||||
pub mod protocol;
|
||||
pub mod shutdown_signal;
|
||||
pub mod singleton_client;
|
||||
pub mod singleton_server;
|
||||
pub mod protocol;
|
||||
|
||||
mod challenge;
|
||||
mod control_server;
|
||||
mod nosleep;
|
||||
#[cfg(target_os = "linux")]
|
||||
|
@ -31,11 +32,9 @@ mod service_macos;
|
|||
#[cfg(target_os = "windows")]
|
||||
mod service_windows;
|
||||
mod socket_signal;
|
||||
mod wsl_server;
|
||||
|
||||
pub use control_server::{serve, Next};
|
||||
pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
|
||||
pub use nosleep::SleepInhibitor;
|
||||
pub use service::{
|
||||
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,
|
||||
};
|
||||
pub use wsl_server::serve_wsl;
|
||||
|
|
41
cli/src/tunnels/challenge.rs
Normal file
41
cli/src/tunnels/challenge.rs
Normal file
|
@ -0,0 +1,41 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
#[cfg(not(feature = "vsda"))]
|
||||
pub fn create_challenge() -> String {
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
Alphanumeric.sample_string(&mut rand::thread_rng(), 16)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "vsda"))]
|
||||
pub fn sign_challenge(challenge: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hash = Sha256::new();
|
||||
hash.update(challenge.as_bytes());
|
||||
let result = hash.finalize();
|
||||
base64::encode_config(result, base64::URL_SAFE_NO_PAD)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "vsda"))]
|
||||
pub fn verify_challenge(challenge: &str, response: &str) -> bool {
|
||||
sign_challenge(challenge) == response
|
||||
}
|
||||
|
||||
#[cfg(feature = "vsda")]
|
||||
pub fn create_challenge() -> String {
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
let str = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||
vsda::create_new_message(&str)
|
||||
}
|
||||
|
||||
#[cfg(feature = "vsda")]
|
||||
pub fn sign_challenge(challenge: &str) -> String {
|
||||
vsda::sign(challenge)
|
||||
}
|
||||
|
||||
#[cfg(feature = "vsda")]
|
||||
pub fn verify_challenge(challenge: &str, response: &str) -> bool {
|
||||
vsda::validate(challenge, response)
|
||||
}
|
|
@ -5,16 +5,15 @@
|
|||
use crate::async_pipe::get_socket_rw_stream;
|
||||
use crate::constants::{CONTROL_PORT, PRODUCT_NAME_LONG};
|
||||
use crate::log;
|
||||
use crate::msgpack_rpc::U32PrefixedCodec;
|
||||
use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization};
|
||||
use crate::msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCodec, MsgPackSerializer};
|
||||
use crate::rpc::{MaybeSync, RpcBuilder, RpcCaller, RpcDispatcher};
|
||||
use crate::self_update::SelfUpdate;
|
||||
use crate::state::LauncherPaths;
|
||||
use crate::tunnels::protocol::HttpRequestParams;
|
||||
use crate::tunnels::protocol::{HttpRequestParams, METHOD_CHALLENGE_ISSUE};
|
||||
use crate::tunnels::socket_signal::CloseReason;
|
||||
use crate::update_service::{Platform, Release, TargetKind, UpdateService};
|
||||
use crate::util::errors::{
|
||||
wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError,
|
||||
NoAttachedServerError,
|
||||
wrap, AnyError, CodeError, MismatchedLaunchModeError, NoAttachedServerError,
|
||||
};
|
||||
use crate::util::http::{
|
||||
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
|
||||
|
@ -22,7 +21,7 @@ use crate::util::http::{
|
|||
use crate::util::io::SilentCopyProgress;
|
||||
use crate::util::is_integrated_cli;
|
||||
use crate::util::os::os_release;
|
||||
use crate::util::sync::{new_barrier, Barrier};
|
||||
use crate::util::sync::{new_barrier, Barrier, BarrierOpener};
|
||||
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::FutureExt;
|
||||
|
@ -31,6 +30,7 @@ use opentelemetry::KeyValue;
|
|||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use tokio::pin;
|
||||
use tokio::process::{ChildStderr, ChildStdin};
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
|
||||
|
@ -39,6 +39,7 @@ use std::time::Instant;
|
|||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
use super::challenge::{create_challenge, sign_challenge, verify_challenge};
|
||||
use super::code_server::{
|
||||
download_cli_into_cache, AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw,
|
||||
SocketCodeServer,
|
||||
|
@ -47,11 +48,12 @@ use super::dev_tunnels::ActiveTunnel;
|
|||
use super::paths::prune_stopped_servers;
|
||||
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
|
||||
use super::protocol::{
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject,
|
||||
ForwardParams, ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse,
|
||||
GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog,
|
||||
ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams,
|
||||
UpdateResult, VersionParams,
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
|
||||
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
|
||||
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
|
||||
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
|
||||
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
|
||||
METHOD_CHALLENGE_VERIFY,
|
||||
};
|
||||
use super::server_bridge::ServerBridge;
|
||||
use super::server_multiplexer::ServerMultiplexer;
|
||||
|
@ -68,6 +70,8 @@ struct HandlerContext {
|
|||
log: log::Logger,
|
||||
/// Whether the server update during the handler session.
|
||||
did_update: Arc<AtomicBool>,
|
||||
/// Whether authentication is still required on the socket.
|
||||
auth_state: Arc<std::sync::Mutex<AuthState>>,
|
||||
/// A loopback channel to talk to the socket server task.
|
||||
socket_tx: mpsc::Sender<SocketSignal>,
|
||||
/// Configured launcher paths.
|
||||
|
@ -79,7 +83,7 @@ struct HandlerContext {
|
|||
// the cli arguments used to start the code server
|
||||
code_server_args: CodeServerArgs,
|
||||
/// port forwarding functionality
|
||||
port_forwarding: PortForwarding,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
/// install platform for the VS Code server
|
||||
platform: Platform,
|
||||
/// http client to make download/update requests
|
||||
|
@ -88,6 +92,16 @@ struct HandlerContext {
|
|||
http_requests: HttpRequestsMap,
|
||||
}
|
||||
|
||||
/// Handler auth state.
|
||||
enum AuthState {
|
||||
/// Auth is required, we're waiting for the client to send its challenge.
|
||||
WaitingForChallenge,
|
||||
/// A challenge has been issued. Waiting for a verification.
|
||||
ChallengeIssued(String),
|
||||
/// Auth is no longer required.
|
||||
Authenticated,
|
||||
}
|
||||
|
||||
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
// Gets a next incrementing number that can be used in logs
|
||||
|
@ -195,7 +209,14 @@ pub async fn serve(
|
|||
debug!(own_log, "Serving new connection");
|
||||
|
||||
let (writehalf, readhalf) = socket.into_split();
|
||||
let stats = process_socket(own_exit, readhalf, writehalf, own_log, own_tx, own_paths, own_code_server_args, own_forwarding, platform).with_context(cx.clone()).await;
|
||||
let stats = process_socket(readhalf, writehalf, own_tx, Some(own_forwarding), ServeStreamParams {
|
||||
log: own_log,
|
||||
launcher_paths: own_paths,
|
||||
code_server_args: own_code_server_args,
|
||||
platform,
|
||||
exit_barrier: own_exit,
|
||||
requires_auth: false,
|
||||
}).with_context(cx.clone()).await;
|
||||
|
||||
cx.span().add_event(
|
||||
"socket.bandwidth",
|
||||
|
@ -206,69 +227,91 @@ pub async fn serve(
|
|||
],
|
||||
);
|
||||
cx.span().end();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SocketStats {
|
||||
pub struct ServeStreamParams {
|
||||
pub log: log::Logger,
|
||||
pub launcher_paths: LauncherPaths,
|
||||
pub code_server_args: CodeServerArgs,
|
||||
pub platform: Platform,
|
||||
pub requires_auth: bool,
|
||||
pub exit_barrier: Barrier<ShutdownSignal>,
|
||||
}
|
||||
|
||||
pub async fn serve_stream(
|
||||
readhalf: impl AsyncRead + Send + Unpin + 'static,
|
||||
writehalf: impl AsyncWrite + Unpin,
|
||||
params: ServeStreamParams,
|
||||
) -> SocketStats {
|
||||
// Currently the only server signal is respawn, that doesn't have much meaning
|
||||
// when serving a stream, so make an ignored channel.
|
||||
let (server_rx, server_tx) = mpsc::channel(1);
|
||||
drop(server_tx);
|
||||
|
||||
process_socket(readhalf, writehalf, server_rx, None, params).await
|
||||
}
|
||||
|
||||
pub struct SocketStats {
|
||||
rx: usize,
|
||||
tx: usize,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct MsgPackSerializer {}
|
||||
|
||||
impl Serialization for MsgPackSerializer {
|
||||
fn serialize(&self, value: impl serde::Serialize) -> Vec<u8> {
|
||||
rmp_serde::to_vec_named(&value).expect("expected to serialize")
|
||||
}
|
||||
|
||||
fn deserialize<P: serde::de::DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError> {
|
||||
rmp_serde::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)] // necessary here
|
||||
async fn process_socket(
|
||||
mut exit_barrier: Barrier<()>,
|
||||
readhalf: impl AsyncRead + Send + Unpin + 'static,
|
||||
mut writehalf: impl AsyncWrite + Unpin,
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn make_socket_rpc(
|
||||
log: log::Logger,
|
||||
server_tx: mpsc::Sender<ServerSignal>,
|
||||
socket_tx: mpsc::Sender<SocketSignal>,
|
||||
http_delegated: DelegatedSimpleHttp,
|
||||
launcher_paths: LauncherPaths,
|
||||
code_server_args: CodeServerArgs,
|
||||
port_forwarding: PortForwarding,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
requires_auth: bool,
|
||||
platform: Platform,
|
||||
) -> SocketStats {
|
||||
let (socket_tx, mut socket_rx) = mpsc::channel(4);
|
||||
let rx_counter = Arc::new(AtomicUsize::new(0));
|
||||
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
let server_bridges = ServerMultiplexer::new();
|
||||
let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone());
|
||||
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
|
||||
did_update: Arc::new(AtomicBool::new(false)),
|
||||
socket_tx: socket_tx.clone(),
|
||||
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
|
||||
true => AuthState::WaitingForChallenge,
|
||||
false => AuthState::Authenticated,
|
||||
})),
|
||||
socket_tx,
|
||||
log: log.clone(),
|
||||
launcher_paths,
|
||||
code_server_args,
|
||||
code_server: Arc::new(Mutex::new(None)),
|
||||
server_bridges: server_bridges.clone(),
|
||||
server_bridges,
|
||||
port_forwarding,
|
||||
platform,
|
||||
http: Arc::new(FallbackSimpleHttp::new(
|
||||
ReqwestSimpleHttp::new(),
|
||||
http_delegated,
|
||||
)),
|
||||
http_requests: http_requests.clone(),
|
||||
http_requests,
|
||||
});
|
||||
|
||||
rpc.register_sync("ping", |_: EmptyObject, _| Ok(EmptyObject {}));
|
||||
rpc.register_sync("gethostname", |_: EmptyObject, _| handle_get_hostname());
|
||||
rpc.register_sync("fs_stat", |p: FsStatRequest, _| handle_stat(p.path));
|
||||
rpc.register_sync("get_env", |_: EmptyObject, _| handle_get_env());
|
||||
rpc.register_sync("fs_stat", |p: FsStatRequest, c| {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_stat(p.path)
|
||||
});
|
||||
rpc.register_sync("get_env", |_: EmptyObject, c| {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_get_env()
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
|
||||
handle_challenge_issue(&c.auth_state)
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
|
||||
handle_challenge_verify(p.response, &c.auth_state)
|
||||
});
|
||||
rpc.register_async("serve", move |params: ServeParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_serve(c, params).await
|
||||
});
|
||||
rpc.register_async("update", |p: UpdateParams, c| async move {
|
||||
|
@ -286,15 +329,19 @@ async fn process_socket(
|
|||
handle_call_server_http(code_server, p).await
|
||||
});
|
||||
rpc.register_async("forward", |p: ForwardParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_forward(&c.log, &c.port_forwarding, p).await
|
||||
});
|
||||
rpc.register_async("unforward", |p: UnforwardParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_unforward(&c.log, &c.port_forwarding, p).await
|
||||
});
|
||||
rpc.register_async("acquire_cli", |p: AcquireCliParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_acquire_cli(&c.launcher_paths, &c.http, &c.log, p).await
|
||||
});
|
||||
rpc.register_duplex("spawn", 3, |mut streams, p: SpawnParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_spawn(
|
||||
&c.log,
|
||||
p,
|
||||
|
@ -304,13 +351,28 @@ async fn process_socket(
|
|||
)
|
||||
.await
|
||||
});
|
||||
rpc.register_duplex(
|
||||
"spawn_cli",
|
||||
3,
|
||||
|mut streams, p: SpawnParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_spawn_cli(
|
||||
&c.log,
|
||||
p,
|
||||
streams.remove(0),
|
||||
streams.remove(0),
|
||||
streams.remove(0),
|
||||
)
|
||||
.await
|
||||
},
|
||||
);
|
||||
rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| {
|
||||
if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) {
|
||||
req.initial_response(p.status_code, p.headers);
|
||||
}
|
||||
Ok(EmptyObject {})
|
||||
});
|
||||
rpc.register_sync("unforward", move |p: HttpBodyParams, c| {
|
||||
rpc.register_sync("httpbody", move |p: HttpBodyParams, c| {
|
||||
let mut reqs = c.http_requests.lock().unwrap();
|
||||
if let Some(req) = reqs.get(&p.req_id) {
|
||||
if !p.segment.is_empty() {
|
||||
|
@ -322,15 +384,64 @@ async fn process_socket(
|
|||
}
|
||||
Ok(EmptyObject {})
|
||||
});
|
||||
rpc.register_sync(
|
||||
"version",
|
||||
|_: EmptyObject, _| Ok(VersionResponse::default()),
|
||||
);
|
||||
|
||||
rpc.build(log)
|
||||
}
|
||||
|
||||
fn ensure_auth(is_authed: &Arc<std::sync::Mutex<AuthState>>) -> Result<(), AnyError> {
|
||||
if let AuthState::Authenticated = &*is_authed.lock().unwrap() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CodeError::ServerAuthRequired.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)] // necessary here
|
||||
async fn process_socket(
|
||||
readhalf: impl AsyncRead + Send + Unpin + 'static,
|
||||
mut writehalf: impl AsyncWrite + Unpin,
|
||||
server_tx: mpsc::Sender<ServerSignal>,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
params: ServeStreamParams,
|
||||
) -> SocketStats {
|
||||
let ServeStreamParams {
|
||||
mut exit_barrier,
|
||||
log,
|
||||
launcher_paths,
|
||||
code_server_args,
|
||||
platform,
|
||||
requires_auth,
|
||||
} = params;
|
||||
|
||||
let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone());
|
||||
let (socket_tx, mut socket_rx) = mpsc::channel(4);
|
||||
let rx_counter = Arc::new(AtomicUsize::new(0));
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
let rpc = make_socket_rpc(
|
||||
log.clone(),
|
||||
socket_tx.clone(),
|
||||
http_delegated,
|
||||
launcher_paths,
|
||||
code_server_args,
|
||||
port_forwarding,
|
||||
requires_auth,
|
||||
platform,
|
||||
);
|
||||
|
||||
{
|
||||
let log = log.clone();
|
||||
let rx_counter = rx_counter.clone();
|
||||
let socket_tx = socket_tx.clone();
|
||||
let exit_barrier = exit_barrier.clone();
|
||||
let rpc = rpc.build(log.clone());
|
||||
tokio::spawn(async move {
|
||||
send_version(&socket_tx).await;
|
||||
if !requires_auth {
|
||||
send_version(&socket_tx).await;
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
handle_socket_read(&log, readhalf, exit_barrier, &socket_tx, rx_counter, &rpc).await
|
||||
|
@ -350,6 +461,10 @@ async fn process_socket(
|
|||
}
|
||||
|
||||
ctx.dispose().await;
|
||||
|
||||
let _ = socket_tx
|
||||
.send(SocketSignal::CloseWith(CloseReason("eof".to_string())))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -408,7 +523,7 @@ async fn process_socket(
|
|||
async fn send_version(tx: &mpsc::Sender<SocketSignal>) {
|
||||
tx.send(SocketSignal::from_message(&ToClientRequest {
|
||||
id: None,
|
||||
params: ClientRequestMethod::version(VersionParams::default()),
|
||||
params: ClientRequestMethod::version(VersionResponse::default()),
|
||||
}))
|
||||
.await
|
||||
.ok();
|
||||
|
@ -416,13 +531,13 @@ async fn send_version(tx: &mpsc::Sender<SocketSignal>) {
|
|||
async fn handle_socket_read(
|
||||
_log: &log::Logger,
|
||||
readhalf: impl AsyncRead + Unpin,
|
||||
mut closer: Barrier<()>,
|
||||
mut closer: Barrier<ShutdownSignal>,
|
||||
socket_tx: &mpsc::Sender<SocketSignal>,
|
||||
rx_counter: Arc<AtomicUsize>,
|
||||
rpc: &RpcDispatcher<MsgPackSerializer, HandlerContext>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
let mut readhalf = BufReader::new(readhalf);
|
||||
let mut decoder = U32PrefixedCodec {};
|
||||
let mut decoder = MsgPackCodec::new();
|
||||
let mut decoder_buf = bytes::BytesMut::new();
|
||||
|
||||
loop {
|
||||
|
@ -431,10 +546,14 @@ async fn handle_socket_read(
|
|||
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
|
||||
}?;
|
||||
|
||||
if read_len == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
rx_counter.fetch_add(read_len, Ordering::Relaxed);
|
||||
|
||||
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
|
||||
match rpc.dispatch(&frame) {
|
||||
match rpc.dispatch_with_partial(&frame.vec, frame.obj) {
|
||||
MaybeSync::Sync(Some(v)) => {
|
||||
if socket_tx.send(SocketSignal::Send(v)).await.is_err() {
|
||||
return Ok(());
|
||||
|
@ -704,11 +823,44 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
|
|||
})
|
||||
}
|
||||
|
||||
fn handle_challenge_issue(
|
||||
auth_state: &Arc<std::sync::Mutex<AuthState>>,
|
||||
) -> Result<ChallengeIssueResponse, AnyError> {
|
||||
let challenge = create_challenge();
|
||||
|
||||
let mut auth_state = auth_state.lock().unwrap();
|
||||
*auth_state = AuthState::ChallengeIssued(challenge.clone());
|
||||
|
||||
Ok(ChallengeIssueResponse { challenge })
|
||||
}
|
||||
|
||||
fn handle_challenge_verify(
|
||||
response: String,
|
||||
auth_state: &Arc<std::sync::Mutex<AuthState>>,
|
||||
) -> Result<EmptyObject, AnyError> {
|
||||
let mut auth_state = auth_state.lock().unwrap();
|
||||
|
||||
match &*auth_state {
|
||||
AuthState::Authenticated => Ok(EmptyObject {}),
|
||||
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
|
||||
false => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
true => {
|
||||
*auth_state = AuthState::Authenticated;
|
||||
Ok(EmptyObject {})
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_forward(
|
||||
log: &log::Logger,
|
||||
port_forwarding: &PortForwarding,
|
||||
port_forwarding: &Option<PortForwarding>,
|
||||
params: ForwardParams,
|
||||
) -> Result<ForwardResult, AnyError> {
|
||||
let port_forwarding = port_forwarding
|
||||
.as_ref()
|
||||
.ok_or(CodeError::PortForwardingNotAvailable)?;
|
||||
info!(log, "Forwarding port {}", params.port);
|
||||
let uri = port_forwarding.forward(params.port).await?;
|
||||
Ok(ForwardResult { uri })
|
||||
|
@ -716,9 +868,12 @@ async fn handle_forward(
|
|||
|
||||
async fn handle_unforward(
|
||||
log: &log::Logger,
|
||||
port_forwarding: &PortForwarding,
|
||||
port_forwarding: &Option<PortForwarding>,
|
||||
params: UnforwardParams,
|
||||
) -> Result<EmptyObject, AnyError> {
|
||||
let port_forwarding = port_forwarding
|
||||
.as_ref()
|
||||
.ok_or(CodeError::PortForwardingNotAvailable)?;
|
||||
info!(log, "Unforwarding port {}", params.port);
|
||||
port_forwarding.unforward(params.port).await?;
|
||||
Ok(EmptyObject {})
|
||||
|
@ -818,17 +973,17 @@ async fn handle_spawn<Stdin, StdoutAndErr>(
|
|||
stderr: Option<StdoutAndErr>,
|
||||
) -> Result<SpawnResult, AnyError>
|
||||
where
|
||||
Stdin: AsyncRead + Unpin + Send,
|
||||
StdoutAndErr: AsyncWrite + Unpin + Send,
|
||||
Stdin: AsyncRead + Unpin + Send + 'static,
|
||||
StdoutAndErr: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
debug!(
|
||||
log,
|
||||
"requested to spawn {} with args {:?}", params.command, params.args
|
||||
);
|
||||
|
||||
macro_rules! pipe_if_some {
|
||||
macro_rules! pipe_if {
|
||||
($e: expr) => {
|
||||
if $e.is_some() {
|
||||
if $e {
|
||||
Stdio::piped()
|
||||
} else {
|
||||
Stdio::null()
|
||||
|
@ -839,9 +994,9 @@ where
|
|||
let mut p = tokio::process::Command::new(¶ms.command);
|
||||
p.args(¶ms.args);
|
||||
p.envs(¶ms.env);
|
||||
p.stdin(pipe_if_some!(stdin));
|
||||
p.stdout(pipe_if_some!(stdout));
|
||||
p.stderr(pipe_if_some!(stderr));
|
||||
p.stdin(pipe_if!(stdin.is_some()));
|
||||
p.stdout(pipe_if!(stdin.is_some()));
|
||||
p.stderr(pipe_if!(stderr.is_some()));
|
||||
if let Some(cwd) = ¶ms.cwd {
|
||||
p.current_dir(cwd);
|
||||
}
|
||||
|
@ -859,7 +1014,72 @@ where
|
|||
futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed());
|
||||
}
|
||||
|
||||
let closed = p.wait();
|
||||
wait_for_process_exit(log, ¶ms.command, p, futs).await
|
||||
}
|
||||
|
||||
async fn handle_spawn_cli(
|
||||
log: &log::Logger,
|
||||
params: SpawnParams,
|
||||
mut protocol_in: DuplexStream,
|
||||
mut protocol_out: DuplexStream,
|
||||
mut log_out: DuplexStream,
|
||||
) -> Result<SpawnResult, AnyError> {
|
||||
debug!(
|
||||
log,
|
||||
"requested to spawn cli {} with args {:?}", params.command, params.args
|
||||
);
|
||||
|
||||
let mut p = tokio::process::Command::new(¶ms.command);
|
||||
p.args(¶ms.args);
|
||||
|
||||
// CLI args to spawn a server; contracted with clients that they should _not_ provide these.
|
||||
p.arg("--verbose");
|
||||
p.arg("tunnel");
|
||||
p.arg("stdio");
|
||||
|
||||
p.envs(¶ms.env);
|
||||
p.stdin(Stdio::piped());
|
||||
p.stdout(Stdio::piped());
|
||||
p.stderr(Stdio::piped());
|
||||
if let Some(cwd) = ¶ms.cwd {
|
||||
p.current_dir(cwd);
|
||||
}
|
||||
|
||||
let mut p = p.spawn().map_err(CodeError::ProcessSpawnFailed)?;
|
||||
|
||||
let mut stdin = p.stdin.take().unwrap();
|
||||
let mut stdout = p.stdout.take().unwrap();
|
||||
let mut stderr = p.stderr.take().unwrap();
|
||||
|
||||
// Start handling logs while doing the handshake in case there's some kind of error
|
||||
let log_pump = tokio::spawn(async move { tokio::io::copy(&mut stdout, &mut log_out).await });
|
||||
|
||||
// note: intentionally do not wrap stdin in a bufreader, since we don't
|
||||
// want to read anything other than our handshake messages.
|
||||
if let Err(e) = spawn_do_child_authentication(log, &mut stdin, &mut stderr).await {
|
||||
warning!(log, "failed to authenticate with child process {}", e);
|
||||
let _ = p.kill().await;
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
debug!(log, "cli authenticated, attaching stdio");
|
||||
let futs = FuturesUnordered::new();
|
||||
futs.push(async move { tokio::io::copy(&mut protocol_in, &mut stdin).await }.boxed());
|
||||
futs.push(async move { tokio::io::copy(&mut stderr, &mut protocol_out).await }.boxed());
|
||||
futs.push(async move { log_pump.await.unwrap() }.boxed());
|
||||
|
||||
wait_for_process_exit(log, ¶ms.command, p, futs).await
|
||||
}
|
||||
|
||||
type TokioCopyFuture = dyn futures::Future<Output = Result<u64, std::io::Error>> + Send;
|
||||
|
||||
async fn wait_for_process_exit(
|
||||
log: &log::Logger,
|
||||
command: &str,
|
||||
mut process: tokio::process::Child,
|
||||
futs: FuturesUnordered<std::pin::Pin<Box<TokioCopyFuture>>>,
|
||||
) -> Result<SpawnResult, AnyError> {
|
||||
let closed = process.wait();
|
||||
pin!(closed);
|
||||
|
||||
let r = tokio::select! {
|
||||
|
@ -880,8 +1100,69 @@ where
|
|||
|
||||
debug!(
|
||||
log,
|
||||
"spawned command {} exited with code {}", params.command, r.exit_code
|
||||
"spawned cli {} exited with code {}", command, r.exit_code
|
||||
);
|
||||
|
||||
Ok(r)
|
||||
}
|
||||
|
||||
async fn spawn_do_child_authentication(
|
||||
log: &log::Logger,
|
||||
stdin: &mut ChildStdin,
|
||||
stdout: &mut ChildStderr,
|
||||
) -> Result<(), CodeError> {
|
||||
let (msg_tx, msg_rx) = mpsc::unbounded_channel();
|
||||
let (shutdown_rx, shutdown) = new_barrier();
|
||||
let mut rpc = new_msgpack_rpc();
|
||||
let caller = rpc.get_caller(msg_tx);
|
||||
|
||||
let challenge_response = do_challenge_response_flow(caller, shutdown);
|
||||
let rpc = start_msgpack_rpc(
|
||||
rpc.methods(()).build(log.prefixed("client-auth")),
|
||||
stdout,
|
||||
stdin,
|
||||
msg_rx,
|
||||
shutdown_rx,
|
||||
);
|
||||
pin!(rpc);
|
||||
|
||||
tokio::select! {
|
||||
r = &mut rpc => {
|
||||
match r {
|
||||
// means shutdown happened cleanly already, we're good
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => Err(CodeError::ProcessSpawnHandshakeFailed(e))
|
||||
}
|
||||
},
|
||||
r = challenge_response => {
|
||||
r?;
|
||||
rpc.await.map(|_| ()).map_err(CodeError::ProcessSpawnFailed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_challenge_response_flow(
|
||||
caller: RpcCaller<MsgPackSerializer>,
|
||||
shutdown: BarrierOpener<()>,
|
||||
) -> Result<(), CodeError> {
|
||||
let challenge: ChallengeIssueResponse = caller
|
||||
.call(METHOD_CHALLENGE_ISSUE, EmptyObject {})
|
||||
.await
|
||||
.unwrap()
|
||||
.map_err(CodeError::TunnelRpcCallFailed)?;
|
||||
|
||||
let _: EmptyObject = caller
|
||||
.call(
|
||||
METHOD_CHALLENGE_VERIFY,
|
||||
ChallengeVerifyParams {
|
||||
response: sign_challenge(&challenge.challenge),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.map_err(CodeError::TunnelRpcCallFailed)?;
|
||||
|
||||
shutdown.open(());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ pub enum ClientRequestMethod<'a> {
|
|||
servermsg(RefServerMessageParams<'a>),
|
||||
serverlog(ServerLog<'a>),
|
||||
makehttpreq(HttpRequestParams<'a>),
|
||||
version(VersionParams),
|
||||
version(VersionResponse),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
|
@ -58,14 +58,6 @@ pub struct ForwardResult {
|
|||
pub uri: String,
|
||||
}
|
||||
|
||||
/// The `install_local` method in the wsl control server
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct InstallFromLocalFolderParams {
|
||||
pub archive_path: String,
|
||||
#[serde(flatten)]
|
||||
pub inner: ServeParams,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ServeParams {
|
||||
pub socket_id: u16,
|
||||
|
@ -165,12 +157,12 @@ pub struct CallServerHttpResult {
|
|||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct VersionParams {
|
||||
pub struct VersionResponse {
|
||||
pub version: &'static str,
|
||||
pub protocol_version: u32,
|
||||
}
|
||||
|
||||
impl Default for VersionParams {
|
||||
impl Default for VersionResponse {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
version: VSCODE_CLI_VERSION.unwrap_or("dev"),
|
||||
|
@ -204,6 +196,19 @@ pub struct SpawnResult {
|
|||
pub exit_code: i32,
|
||||
}
|
||||
|
||||
pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
|
||||
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChallengeIssueResponse {
|
||||
pub challenge: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct ChallengeVerifyParams {
|
||||
pub response: String,
|
||||
}
|
||||
|
||||
pub mod singleton {
|
||||
use crate::log;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
|
@ -74,7 +74,6 @@ pub async fn start_singleton_client(args: SingletonClientArgs) -> bool {
|
|||
let mut input = String::new();
|
||||
loop {
|
||||
input.truncate(0);
|
||||
println!("reading line");
|
||||
match std::io::stdin().read_line(&mut input) {
|
||||
Err(_) | Ok(0) => return, // EOF or not a tty
|
||||
_ => {}
|
||||
|
|
|
@ -38,6 +38,7 @@ impl SocketSignal {
|
|||
}
|
||||
|
||||
/// todo@connor4312: cleanup once everything is moved to rpc standard interfaces
|
||||
#[allow(dead_code)]
|
||||
pub enum ServerMessageDestination {
|
||||
Channel(mpsc::Sender<SocketSignal>),
|
||||
Rpc(MsgPackCaller),
|
||||
|
|
|
@ -1,173 +0,0 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{
|
||||
log,
|
||||
msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCaller},
|
||||
state::LauncherPaths,
|
||||
tunnels::code_server::ServerBuilder,
|
||||
update_service::{Platform, Release, TargetKind},
|
||||
util::{
|
||||
errors::{
|
||||
wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError,
|
||||
},
|
||||
http::ReqwestSimpleHttp,
|
||||
sync::Barrier,
|
||||
},
|
||||
};
|
||||
|
||||
use super::{
|
||||
code_server::{AnyCodeServer, CodeServerArgs, ResolvedServerParams},
|
||||
protocol::{EmptyObject, InstallFromLocalFolderParams, ServerMessageParams, VersionParams},
|
||||
server_bridge::ServerBridge,
|
||||
server_multiplexer::ServerMultiplexer,
|
||||
shutdown_signal::ShutdownSignal,
|
||||
socket_signal::{ClientMessageDecoder, ServerMessageDestination, ServerMessageSink},
|
||||
};
|
||||
|
||||
struct HandlerContext {
|
||||
log: log::Logger,
|
||||
code_server_args: CodeServerArgs,
|
||||
launcher_paths: LauncherPaths,
|
||||
platform: Platform,
|
||||
http: ReqwestSimpleHttp,
|
||||
caller: MsgPackCaller,
|
||||
multiplexer: ServerMultiplexer,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct RpcLogSink(MsgPackCaller);
|
||||
|
||||
impl RpcLogSink {
|
||||
fn write_json(&self, level: String, message: &str) {
|
||||
self.0.notify(
|
||||
"log",
|
||||
serde_json::json!({
|
||||
"level": level,
|
||||
"message": message,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl log::LogSink for RpcLogSink {
|
||||
fn write_log(&self, level: log::Level, _prefix: &str, message: &str) {
|
||||
self.write_json(level.to_string(), message);
|
||||
}
|
||||
|
||||
fn write_result(&self, message: &str) {
|
||||
self.write_json("result".to_string(), message);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_wsl(
|
||||
log: log::Logger,
|
||||
launcher_paths: LauncherPaths,
|
||||
code_server_args: CodeServerArgs,
|
||||
platform: Platform,
|
||||
http: reqwest::Client,
|
||||
shutdown_rx: Barrier<ShutdownSignal>,
|
||||
) -> Result<i32, AnyError> {
|
||||
let (caller_tx, caller_rx) = mpsc::unbounded_channel();
|
||||
let mut rpc = new_msgpack_rpc();
|
||||
let caller = rpc.get_caller(caller_tx);
|
||||
|
||||
// notify the incoming client about the server version
|
||||
caller.notify("version", VersionParams::default());
|
||||
|
||||
let log = log.with_sink(RpcLogSink(caller.clone()));
|
||||
let mut rpc = rpc.methods(HandlerContext {
|
||||
log: log.clone(),
|
||||
caller,
|
||||
code_server_args,
|
||||
launcher_paths,
|
||||
platform,
|
||||
multiplexer: ServerMultiplexer::new(),
|
||||
http: ReqwestSimpleHttp::with_client(http),
|
||||
});
|
||||
|
||||
rpc.register_async(
|
||||
"serve",
|
||||
move |m: InstallFromLocalFolderParams, c| async move { handle_serve(&c, m).await },
|
||||
);
|
||||
rpc.register_sync("servermsg", move |m: ServerMessageParams, c| {
|
||||
if c.multiplexer.write_message(&c.log, m.i, m.body) {
|
||||
Ok(EmptyObject {})
|
||||
} else {
|
||||
Err(NoAttachedServerError().into())
|
||||
}
|
||||
});
|
||||
|
||||
start_msgpack_rpc(
|
||||
rpc.build(log),
|
||||
tokio::io::stdin(),
|
||||
tokio::io::stderr(),
|
||||
caller_rx,
|
||||
shutdown_rx,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| wrap(e, "error handling server stdio"))?;
|
||||
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn handle_serve(
|
||||
c: &HandlerContext,
|
||||
params: InstallFromLocalFolderParams,
|
||||
) -> Result<EmptyObject, AnyError> {
|
||||
// fill params.extensions into code_server_args.install_extensions
|
||||
let mut csa = c.code_server_args.clone();
|
||||
csa.connection_token = params.inner.connection_token.or(csa.connection_token);
|
||||
csa.install_extensions
|
||||
.extend(params.inner.extensions.into_iter());
|
||||
|
||||
let resolved = ResolvedServerParams {
|
||||
code_server_args: csa,
|
||||
release: Release {
|
||||
name: String::new(),
|
||||
commit: params
|
||||
.inner
|
||||
.commit_id
|
||||
.ok_or_else(|| InvalidRpcDataError("commit_id is required".to_string()))?,
|
||||
platform: c.platform,
|
||||
target: TargetKind::Server,
|
||||
quality: params.inner.quality,
|
||||
},
|
||||
};
|
||||
|
||||
let sb = ServerBuilder::new(
|
||||
&c.log,
|
||||
&resolved,
|
||||
&c.launcher_paths,
|
||||
Arc::new(c.http.clone()),
|
||||
);
|
||||
let code_server = match sb.get_running().await? {
|
||||
Some(AnyCodeServer::Socket(s)) => s,
|
||||
Some(_) => return Err(MismatchedLaunchModeError().into()),
|
||||
None => {
|
||||
sb.setup().await?;
|
||||
sb.listen_on_default_socket().await?
|
||||
}
|
||||
};
|
||||
|
||||
let bridge = ServerBridge::new(
|
||||
&code_server.socket,
|
||||
ServerMessageSink::new_plain(
|
||||
c.multiplexer.clone(),
|
||||
params.inner.socket_id,
|
||||
ServerMessageDestination::Rpc(c.caller.clone()),
|
||||
),
|
||||
ClientMessageDecoder::new_plain(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
c.multiplexer.register(params.inner.socket_id, bridge);
|
||||
trace!(c.log, "Attached to server");
|
||||
Ok(EmptyObject {})
|
||||
}
|
|
@ -479,9 +479,18 @@ pub enum CodeError {
|
|||
PrerequisitesFailed { name: &'static str, bullets: String },
|
||||
#[error("failed to spawn process: {0:?}")]
|
||||
ProcessSpawnFailed(std::io::Error),
|
||||
|
||||
#[error("failed to handshake spawned process: {0:?}")]
|
||||
ProcessSpawnHandshakeFailed(std::io::Error),
|
||||
#[error("download appears corrupted, please retry ({0})")]
|
||||
CorruptDownload(&'static str),
|
||||
#[error("port forwarding is not available in this context")]
|
||||
PortForwardingNotAvailable,
|
||||
#[error("'auth' call required")]
|
||||
ServerAuthRequired,
|
||||
#[error("challenge not yet issued")]
|
||||
AuthChallengeNotIssued,
|
||||
#[error("unauthorized client refused")]
|
||||
AuthMismatch,
|
||||
}
|
||||
|
||||
makeAnyError!(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"name": "code-oss-dev",
|
||||
"version": "1.79.0",
|
||||
"distro": "10ec4d08d4a06a1c18addcded03e90c8a0e6ecad",
|
||||
"distro": "edc0b65674651d40e7bd668fbc675ddb30cec375",
|
||||
"author": {
|
||||
"name": "Microsoft Corporation"
|
||||
},
|
||||
|
|
Loading…
Reference in a new issue