Run tunnels as singleton process (for a --cli-data-dir) (#177002)

* wip on singleton

* wip

* windows support

* wip

* wip

* fix clippy
This commit is contained in:
Connor Peet 2023-03-14 08:09:47 -07:00 committed by GitHub
parent bed3a7761e
commit 1b5fd140fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 1283 additions and 311 deletions

View file

@ -60,6 +60,7 @@ steps:
VSCODE_CLI_ENV:
OPENSSL_LIB_DIR: $(Build.ArtifactStagingDirectory)/openssl/x64-windows-static-md/lib
OPENSSL_INCLUDE_DIR: $(Build.ArtifactStagingDirectory)/openssl/x64-windows-static-md/include
RUSTFLAGS: '-C target-feature=+crt-static'
- ${{ if eq(parameters.VSCODE_BUILD_WIN32_ARM64, true) }}:
- template: ../cli/cli-compile-and-publish.yml
@ -69,6 +70,7 @@ steps:
VSCODE_CLI_ENV:
OPENSSL_LIB_DIR: $(Build.ArtifactStagingDirectory)/openssl/arm64-windows-static-md/lib
OPENSSL_INCLUDE_DIR: $(Build.ArtifactStagingDirectory)/openssl/arm64-windows-static-md/include
RUSTFLAGS: '-C target-feature=+crt-static'
- ${{ if eq(parameters.VSCODE_BUILD_WIN32_32BIT, true) }}:
- template: ../cli/cli-compile-and-publish.yml
@ -78,3 +80,4 @@ steps:
VSCODE_CLI_ENV:
OPENSSL_LIB_DIR: $(Build.ArtifactStagingDirectory)/openssl/x86-windows-static-md/lib
OPENSSL_INCLUDE_DIR: $(Build.ArtifactStagingDirectory)/openssl/x86-windows-static-md/include
RUSTFLAGS: '-C target-feature=+crt-static'

View file

@ -1,2 +0,0 @@
[target.'cfg(all(windows, target_env = "msvc"))']
rustflags = ["-C", "target-feature=+crt-static"]

31
cli/Cargo.lock generated
View file

@ -230,6 +230,7 @@ dependencies = [
"async-trait",
"atty",
"base64",
"cfg-if",
"chrono",
"clap",
"clap_lex",
@ -249,6 +250,7 @@ dependencies = [
"open",
"opentelemetry",
"opentelemetry-application-insights",
"pin-project",
"rand 0.8.5",
"regex",
"reqwest",
@ -261,6 +263,7 @@ dependencies = [
"sysinfo",
"tar",
"tempfile",
"thiserror",
"tokio",
"tokio-util",
"tunnels",
@ -1533,6 +1536,26 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.9"
@ -2207,18 +2230,18 @@ checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16"
[[package]]
name = "thiserror"
version = "1.0.37"
version = "1.0.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
version = "1.0.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e"
dependencies = [
"proc-macro2",
"quote",

View file

@ -50,6 +50,9 @@ const_format = "0.2"
sha2 = "0.10"
base64 = "0.13"
shell-escape = "0.1.5"
thiserror = "1.0"
cfg-if = "1.0.0"
pin-project = "1.0"
[build-dependencies]
serde = { version = "1.0" }

183
cli/src/async_pipe.rs Normal file
View file

@ -0,0 +1,183 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
use std::path::{Path, PathBuf};
use uuid::Uuid;
// todo: we could probably abstract this into some crate, if one doesn't already exist
cfg_if::cfg_if! {
if #[cfg(unix)] {
pub type AsyncPipe = tokio::net::UnixStream;
pub type AsyncPipeWriteHalf = tokio::net::unix::OwnedWriteHalf;
pub type AsyncPipeReadHalf = tokio::net::unix::OwnedReadHalf;
pub async fn get_socket_rw_stream(path: &Path) -> Result<AsyncPipe, CodeError> {
tokio::net::UnixStream::connect(path)
.await
.map_err(CodeError::AsyncPipeFailed)
}
pub async fn listen_socket_rw_stream(path: &Path) -> Result<AsyncPipeListener, CodeError> {
tokio::net::UnixListener::bind(path)
.map(AsyncPipeListener)
.map_err(CodeError::AsyncPipeListenerFailed)
}
pub struct AsyncPipeListener(tokio::net::UnixListener);
impl AsyncPipeListener {
pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
self.0.accept().await.map_err(CodeError::AsyncPipeListenerFailed).map(|(s, _)| s)
}
}
pub fn socket_stream_split(pipe: AsyncPipe) -> (AsyncPipeReadHalf, AsyncPipeWriteHalf) {
pipe.into_split()
}
} else {
use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
use pin_project::pin_project;
#[pin_project(project = AsyncPipeProj)]
pub enum AsyncPipe {
PipeClient(#[pin] NamedPipeClient),
PipeServer(#[pin] NamedPipeServer),
}
impl AsyncRead for AsyncPipe {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.project() {
AsyncPipeProj::PipeClient(c) => c.poll_read(cx, buf),
AsyncPipeProj::PipeServer(c) => c.poll_read(cx, buf),
}
}
}
impl AsyncWrite for AsyncPipe {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.project() {
AsyncPipeProj::PipeClient(c) => c.poll_write(cx, buf),
AsyncPipeProj::PipeServer(c) => c.poll_write(cx, buf),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
match self.project() {
AsyncPipeProj::PipeClient(c) => c.poll_write_vectored(cx, bufs),
AsyncPipeProj::PipeServer(c) => c.poll_write_vectored(cx, bufs),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.project() {
AsyncPipeProj::PipeClient(c) => c.poll_flush(cx),
AsyncPipeProj::PipeServer(c) => c.poll_flush(cx),
}
}
fn is_write_vectored(&self) -> bool {
match self {
AsyncPipe::PipeClient(c) => c.is_write_vectored(),
AsyncPipe::PipeServer(c) => c.is_write_vectored(),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.project() {
AsyncPipeProj::PipeClient(c) => c.poll_shutdown(cx),
AsyncPipeProj::PipeServer(c) => c.poll_shutdown(cx),
}
}
}
pub type AsyncPipeWriteHalf = tokio::io::WriteHalf<AsyncPipe>;
pub type AsyncPipeReadHalf = tokio::io::ReadHalf<AsyncPipe>;
pub async fn get_socket_rw_stream(path: &Path) -> Result<AsyncPipe, CodeError> {
// Tokio says we can need to try in a loop. Do so.
// https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
let client = loop {
match ClientOptions::new().open(path) {
Ok(client) => break client,
// ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499-
Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await,
Err(e) => return Err(CodeError::AsyncPipeFailed(e)),
}
};
Ok(AsyncPipe::PipeClient(client))
}
pub struct AsyncPipeListener {
path: PathBuf,
server: NamedPipeServer
}
impl AsyncPipeListener {
pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
// see https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeServer.html
// this is a bit weird in that the server becomes the client once
// they get a connection, and we create a new client.
self.server
.connect()
.await
.map_err(CodeError::AsyncPipeListenerFailed)?;
// Construct the next server to be connected before sending the one
// we already have of onto a task. This ensures that the server
// isn't closed (after it's done in the task) before a new one is
// available. Otherwise the client might error with
// `io::ErrorKind::NotFound`.
let next_server = ServerOptions::new()
.create(&self.path)
.map_err(CodeError::AsyncPipeListenerFailed)?;
Ok(AsyncPipe::PipeServer(std::mem::replace(&mut self.server, next_server)))
}
}
pub async fn listen_socket_rw_stream(path: &Path) -> Result<AsyncPipeListener, CodeError> {
let server = ServerOptions::new()
.first_pipe_instance(true)
.create(path)
.map_err(CodeError::AsyncPipeListenerFailed)?;
Ok(AsyncPipeListener { path: path.to_owned(), server })
}
pub fn socket_stream_split(pipe: AsyncPipe) -> (AsyncPipeReadHalf, AsyncPipeWriteHalf) {
tokio::io::split(pipe)
}
}
}
/// Gets a random name for a pipe/socket on the paltform
pub fn get_socket_name() -> PathBuf {
cfg_if::cfg_if! {
if #[cfg(unix)] {
std::env::temp_dir().join(format!("{}-{}", APPLICATION_NAME, Uuid::new_v4()))
} else {
PathBuf::from(format!(r"\\.\pipe\{}-{}", APPLICATION_NAME, Uuid::new_v4()))
}
}
}

View file

@ -6,8 +6,8 @@
mod context;
pub mod args;
pub mod internal_wsl;
pub mod tunnels;
pub mod update;
pub mod version;
pub mod internal_wsl;
pub use context::CommandContext;

View file

@ -4,14 +4,14 @@
*--------------------------------------------------------------------------------------------*/
use crate::{
tunnels::{serve_wsl, shutdown_signal::ShutdownSignal},
tunnels::{serve_wsl, shutdown_signal::ShutdownRequest},
util::{errors::AnyError, prereqs::PreReqChecker},
};
use super::CommandContext;
pub async fn serve(ctx: CommandContext) -> Result<i32, AnyError> {
let signal = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]);
let signal = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]);
let platform = spanf!(
ctx.log,
ctx.log.span("prereq"),

View file

@ -5,9 +5,8 @@
use async_trait::async_trait;
use sha2::{Digest, Sha256};
use std::str::FromStr;
use std::{str::FromStr, time::Duration};
use sysinfo::Pid;
use tokio::sync::mpsc;
use super::{
args::{
@ -17,21 +16,31 @@ use super::{
CommandContext,
};
use crate::tunnels::shutdown_signal::ShutdownSignal;
use crate::tunnels::{dev_tunnels::ActiveTunnel, SleepInhibitor};
use crate::{
auth::Auth,
log::{self, Logger},
state::LauncherPaths,
tunnels::{
code_server::CodeServerArgs, create_service_manager, dev_tunnels, legal,
paths::get_all_servers, ServiceContainer, ServiceManager,
code_server::CodeServerArgs,
create_service_manager, dev_tunnels, legal,
paths::get_all_servers,
shutdown_signal::ShutdownRequest,
singleton_server::{start_singleton_server, SingletonServerArgs, BroadcastLogSink},
ServiceContainer, ServiceManager,
},
util::{
errors::{wrap, AnyError},
prereqs::PreReqChecker,
},
};
use crate::{
singleton::{acquire_singleton, SingletonConnection},
tunnels::{
dev_tunnels::ActiveTunnel,
singleton_client::{start_singleton_client, SingletonClientArgs},
SleepInhibitor,
},
};
impl From<AuthProvider> for crate::auth::AuthProvider {
fn from(auth_provider: AuthProvider) -> Self {
@ -75,7 +84,6 @@ impl ServiceContainer for TunnelServiceContainer {
&mut self,
log: log::Logger,
launcher_paths: LauncherPaths,
shutdown_rx: mpsc::UnboundedReceiver<ShutdownSignal>,
) -> Result<(), AnyError> {
let csa = (&self.args).into();
serve_with_csa(
@ -86,7 +94,6 @@ impl ServiceContainer for TunnelServiceContainer {
..Default::default()
},
csa,
Some(shutdown_rx),
)
.await?;
Ok(())
@ -227,7 +234,7 @@ pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result
legal::require_consent(&paths, gateway_args.accept_server_license_terms)?;
let csa = (&args).into();
let result = serve_with_csa(paths, log, gateway_args, csa, None).await;
let result = serve_with_csa(paths, log, gateway_args, csa).await;
drop(no_sleep);
result
@ -242,15 +249,52 @@ fn get_connection_token(tunnel: &ActiveTunnel) -> String {
async fn serve_with_csa(
paths: LauncherPaths,
log: Logger,
mut log: Logger,
gateway_args: TunnelServeArgs,
mut csa: CodeServerArgs,
shutdown_rx: Option<mpsc::UnboundedReceiver<ShutdownSignal>>,
) -> Result<i32, AnyError> {
let shutdown = match gateway_args
.parent_process_id
.and_then(|p| Pid::from_str(&p).ok())
{
Some(pid) => ShutdownRequest::create_rx([
ShutdownRequest::CtrlC,
ShutdownRequest::ParentProcessKilled(pid),
]),
None => ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
};
// Intentionally read before starting the server. If the server updated and
// respawn is requested, the old binary will get renamed, and then
// current_exe will point to the wrong path.
let current_exe = std::env::current_exe().unwrap();
let server = loop {
if shutdown.is_open() {
return Ok(0);
}
match acquire_singleton(paths.root().join("tunnel.lock")).await {
Ok(SingletonConnection::Client(stream)) => {
debug!(log, "starting as client to singleton");
start_singleton_client(SingletonClientArgs {
log: log.clone(),
shutdown: shutdown.clone(),
stream,
})
.await
}
Ok(SingletonConnection::Singleton(server)) => break server,
Err(e) => {
warning!(log, "error access singleton, retrying: {}", e);
tokio::time::sleep(Duration::from_secs(2)).await
}
}
};
debug!(log, "starting as new singleton");
let log_broadcast = BroadcastLogSink::new();
log = log.tee(log_broadcast.clone());
let platform = spanf!(log, log.span("prereq"), PreReqChecker::new().verify())?;
let auth = Auth::new(&paths, log.clone());
@ -264,21 +308,17 @@ async fn serve_with_csa(
csa.connection_token = Some(get_connection_token(&tunnel));
let shutdown_tx = if let Some(tx) = shutdown_rx {
tx
} else if let Some(pid) = gateway_args
.parent_process_id
.and_then(|p| Pid::from_str(&p).ok())
{
ShutdownSignal::create_rx(&[
ShutdownSignal::CtrlC,
ShutdownSignal::ParentProcessKilled(pid),
])
} else {
ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC])
};
let mut r = crate::tunnels::serve(&log, tunnel, &paths, &csa, platform, shutdown_tx).await?;
let mut r = start_singleton_server(SingletonServerArgs {
log: log.clone(),
tunnel,
paths,
code_server_args: csa,
platform,
log_broadcast,
shutdown,
server,
})
.await?;
r.tunnel.close().await.ok();
if r.respawn {

View file

@ -5,12 +5,16 @@
use tokio::{
io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
pin,
sync::mpsc,
};
use crate::{
rpc::{self, MaybeSync, Serialization},
util::errors::InvalidRpcDataError,
util::{
errors::InvalidRpcDataError,
sync::{Barrier, Receivable},
},
};
use std::io;
@ -39,34 +43,38 @@ pub fn new_json_rpc() -> rpc::RpcBuilder<JsonRpcSerializer> {
}
#[allow(dead_code)]
pub async fn start_json_rpc<C: Send + Sync + 'static, S>(
pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
dispatcher: rpc::RpcDispatcher<JsonRpcSerializer, C>,
read: impl AsyncRead + Unpin,
mut write: impl AsyncWrite + Unpin,
mut msg_rx: mpsc::UnboundedReceiver<Vec<u8>>,
mut shutdown_rx: mpsc::UnboundedReceiver<S>,
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let mut read = BufReader::new(read);
let mut read_buf = String::new();
let shutdown_fut = shutdown_rx.wait();
pin!(shutdown_fut);
loop {
tokio::select! {
r = shutdown_rx.recv() => return Ok(r),
r = &mut shutdown_fut => return Ok(r.ok()),
Some(w) = write_rx.recv() => {
write.write_all(&w).await?;
},
Some(w) = msg_rx.recv() => {
Some(w) = msg_rx.recv_msg() => {
write.write_all(&w).await?;
},
n = read.read_line(&mut read_buf) => {
let r = match n {
Ok(0) => return Ok(None),
Ok(n) => dispatcher.dispatch(read_buf[..n].as_bytes()),
Ok(n) => dispatcher.dispatch(read_buf[..n].as_bytes()),
Err(e) => return Err(e)
};
read_buf.truncate(0);
match r {
MaybeSync::Sync(Some(v)) => {
write_tx.send(v).ok();

View file

@ -18,6 +18,8 @@ pub mod tunnels;
pub mod update_service;
pub mod util;
mod rpc;
mod async_pipe;
mod json_rpc;
mod msgpack_rpc;
mod rpc;
mod singleton;

View file

@ -8,6 +8,7 @@ use opentelemetry::{
sdk::trace::{Tracer, TracerProvider},
trace::{SpanBuilder, Tracer as TraitTracer, TracerProvider as TracerProviderTrait},
};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::{env, path::Path, sync::Arc};
use std::{
@ -25,7 +26,7 @@ pub fn next_counter() -> u32 {
}
// Log level
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug)]
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)]
pub enum Level {
Trace = 0,
Debug,

View file

@ -5,12 +5,16 @@
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader},
pin,
sync::mpsc,
};
use crate::{
rpc::{self, MaybeSync, Serialization},
util::errors::{AnyError, InvalidRpcDataError},
util::{
errors::{AnyError, InvalidRpcDataError},
sync::{Barrier, Receivable},
},
};
use std::io;
@ -35,17 +39,20 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
}
#[allow(clippy::read_zero_byte_vec)] // false positive
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S>(
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
dispatcher: rpc::RpcDispatcher<MsgPackSerializer, C>,
read: impl AsyncRead + Unpin,
mut write: impl AsyncWrite + Unpin,
mut msg_rx: mpsc::UnboundedReceiver<Vec<u8>>,
mut shutdown_rx: mpsc::UnboundedReceiver<S>,
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let mut read = BufReader::new(read);
let mut decode_buf = vec![];
let shutdown_fut = shutdown_rx.wait();
pin!(shutdown_fut);
loop {
tokio::select! {
u = read.read_u32() => {
@ -66,16 +73,16 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S>(
});
}
},
r = shutdown_rx.recv() => return Ok(r),
r = &mut shutdown_fut => return Ok(r.ok()),
};
},
Some(m) = write_rx.recv() => {
write.write_all(&m).await?;
},
Some(m) = msg_rx.recv() => {
Some(m) = msg_rx.recv_msg() => {
write.write_all(&m).await?;
},
r = shutdown_rx.recv() => return Ok(r),
r = &mut shutdown_fut => return Ok(r.ok()),
}
write.flush().await?;

View file

@ -204,19 +204,28 @@ pub struct RpcCaller<S: Serialization> {
}
impl<S: Serialization> RpcCaller<S> {
pub fn serialize_notify<M, A>(serializer: &S, method: M, params: A) -> Vec<u8>
where
S: Serialization,
M: AsRef<str> + serde::Serialize,
A: Serialize,
{
serializer.serialize(&FullRequest {
id: None,
method,
params,
})
}
/// Enqueues an outbound call. Returns whether the message was enqueued.
pub fn notify<M, A>(&self, method: M, params: A) -> bool
where
M: Into<String>,
M: AsRef<str> + serde::Serialize,
A: Serialize,
{
let body = self.serializer.serialize(&FullRequest {
id: None,
method: method.into(),
params,
});
self.sender.send(body).is_ok()
self.sender
.send(Self::serialize_notify(&self.serializer, method, params))
.is_ok()
}
/// Enqueues an outbound call, returning its result.
@ -227,7 +236,7 @@ impl<S: Serialization> RpcCaller<S> {
params: A,
) -> oneshot::Receiver<Result<R, ResponseError>>
where
M: Into<String>,
M: AsRef<str> + serde::Serialize,
A: Serialize,
R: DeserializeOwned + Send + 'static,
{
@ -235,7 +244,7 @@ impl<S: Serialization> RpcCaller<S> {
let id = next_message_id();
let body = self.serializer.serialize(&FullRequest {
id: Some(id),
method: method.into(),
method,
params,
});
@ -349,9 +358,9 @@ struct PartialIncoming {
}
#[derive(Serialize)]
pub struct FullRequest<P> {
pub struct FullRequest<M: AsRef<str>, P> {
pub id: Option<u32>,
pub method: String,
pub method: M,
pub params: P,
}

177
cli/src/singleton.rs Normal file
View file

@ -0,0 +1,177 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use serde::{Deserialize, Serialize};
use std::{
fs::{File, OpenOptions},
io::{Seek, SeekFrom, Write},
path::{Path, PathBuf},
time::Duration,
};
use sysinfo::{Pid, PidExt};
use crate::{
async_pipe::{
get_socket_name, get_socket_rw_stream, listen_socket_rw_stream, AsyncPipe,
AsyncPipeListener,
},
util::{
errors::CodeError,
file_lock::{FileLock, Lock, PREFIX_LOCKED_BYTES},
machine::wait_until_process_exits,
},
};
pub struct SingletonServer {
server: AsyncPipeListener,
_lock: FileLock,
}
impl SingletonServer {
pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
self.server.accept().await
}
}
pub enum SingletonConnection {
/// This instance got the singleton lock. It started listening on a socket
/// and has the read/write pair. If this gets dropped, the lock is released.
Singleton(SingletonServer),
/// Another instance is a singleton, and this client connected to it.
Client(AsyncPipe),
}
/// Contents of the lock file; the listening socket ID and process ID
/// doing the listening.
#[derive(Deserialize, Serialize)]
struct LockFileMatter {
socket_path: String,
pid: u32,
}
pub async fn acquire_singleton(lock_file: PathBuf) -> Result<SingletonConnection, CodeError> {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&lock_file)
.map_err(CodeError::SingletonLockfileOpenFailed)?;
match FileLock::acquire(file) {
Ok(Lock::AlreadyLocked(mut file)) => connect_as_client(&mut file).await,
Ok(Lock::Acquired(lock)) => start_singleton_server(lock).await,
Err(e) => Err(e),
}
}
async fn start_singleton_server(mut lock: FileLock) -> Result<SingletonConnection, CodeError> {
let socket_path = get_socket_name();
let mut vec = Vec::with_capacity(128);
let _ = vec.write(&[0; PREFIX_LOCKED_BYTES]);
let _ = rmp_serde::encode::write(
&mut vec,
&LockFileMatter {
socket_path: socket_path.to_string_lossy().to_string(),
pid: std::process::id(),
},
);
lock.file_mut()
.write_all(&vec)
.map_err(CodeError::SingletonLockfileOpenFailed)?;
let server = listen_socket_rw_stream(&socket_path).await?;
Ok(SingletonConnection::Singleton(SingletonServer {
server,
_lock: lock,
}))
}
const MAX_CLIENT_ATTEMPTS: i32 = 10;
async fn connect_as_client(mut file: &mut File) -> Result<SingletonConnection, CodeError> {
// retry, since someone else could get a lock and we could read it before
// the JSON info was finished writing out
let mut attempt = 0;
loop {
let _ = file.seek(SeekFrom::Start(PREFIX_LOCKED_BYTES as u64));
let r = match rmp_serde::from_read::<_, LockFileMatter>(&mut file) {
Ok(prev) => {
let socket_path = PathBuf::from(prev.socket_path);
tokio::select! {
p = retry_get_socket_rw_stream(&socket_path, 5, Duration::from_millis(500)) => p,
_ = wait_until_process_exits(Pid::from_u32(prev.pid), 500) => Err(CodeError::SingletonLockedProcessExited(prev.pid)),
}
}
Err(e) => Err(CodeError::SingletonLockfileReadFailed(e)),
};
if r.is_ok() || attempt == MAX_CLIENT_ATTEMPTS {
return r.map(SingletonConnection::Client);
}
attempt += 1;
tokio::time::sleep(Duration::from_millis(500)).await;
}
}
async fn retry_get_socket_rw_stream(
path: &Path,
max_tries: usize,
interval: Duration,
) -> Result<AsyncPipe, CodeError> {
for i in 0.. {
match get_socket_rw_stream(path).await {
Ok(s) => return Ok(s),
Err(e) if i == max_tries => return Err(e),
Err(_) => tokio::time::sleep(interval).await,
}
}
unreachable!()
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_acquires_singleton() {
let dir = tempfile::tempdir().expect("expected to make temp dir");
let s = acquire_singleton(dir.path().join("lock"))
.await
.expect("expected to acquire");
match s {
SingletonConnection::Singleton(_) => {}
_ => panic!("expected to be singleton"),
}
}
#[tokio::test]
async fn test_acquires_client() {
let dir = tempfile::tempdir().expect("expected to make temp dir");
let lockfile = dir.path().join("lock");
let s1 = acquire_singleton(lockfile.clone())
.await
.expect("expected to acquire1");
match s1 {
SingletonConnection::Singleton(mut l) => tokio::spawn(async move {
l.accept().await.expect("expected to accept");
}),
_ => panic!("expected to be singleton"),
};
let s2 = acquire_singleton(lockfile)
.await
.expect("expected to acquire2");
match s2 {
SingletonConnection::Client(_) => {}
_ => panic!("expected to be client"),
}
}
}

View file

@ -8,6 +8,8 @@ pub mod dev_tunnels;
pub mod legal;
pub mod paths;
pub mod shutdown_signal;
pub mod singleton_client;
pub mod singleton_server;
mod control_server;
mod nosleep;
@ -19,8 +21,6 @@ mod nosleep_macos;
mod nosleep_windows;
mod port_forwarder;
mod protocol;
#[cfg_attr(unix, path = "tunnels/server_bridge_unix.rs")]
#[cfg_attr(windows, path = "tunnels/server_bridge_windows.rs")]
mod server_bridge;
mod server_multiplexer;
mod service;

View file

@ -3,6 +3,7 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use super::paths::{InstalledServer, LastUsedServers, ServerPaths};
use crate::async_pipe::get_socket_name;
use crate::constants::{APPLICATION_NAME, QUALITYLESS_PRODUCT_NAME, QUALITYLESS_SERVER_NAME};
use crate::options::{Quality, TelemetryLevel};
use crate::state::LauncherPaths;
@ -32,7 +33,6 @@ use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::oneshot::Receiver;
use tokio::time::{interval, timeout};
use uuid::Uuid;
lazy_static! {
static ref LISTENING_PORT_RE: Regex =
@ -539,12 +539,7 @@ impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Htt
}
pub async fn listen_on_default_socket(&self) -> Result<SocketCodeServer, AnyError> {
let requested_file = if cfg!(target_os = "windows") {
PathBuf::from(format!(r"\\.\pipe\vscode-server-{}", Uuid::new_v4()))
} else {
std::env::temp_dir().join(format!("vscode-server-{}", Uuid::new_v4()))
};
let requested_file = get_socket_name();
self.listen_on_socket(&requested_file).await
}

View file

@ -2,6 +2,7 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::async_pipe::get_socket_rw_stream;
use crate::constants::{CONTROL_PORT, EDITOR_WEB_URL, QUALITYLESS_SERVER_NAME};
use crate::log;
use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization};
@ -30,7 +31,6 @@ use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::pin;
use tokio::sync::{mpsc, Mutex};
use super::code_server::{
@ -45,7 +45,7 @@ use super::protocol::{
ServerMessageParams, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult,
VersionParams,
};
use super::server_bridge::{get_socket_rw_stream, ServerBridge};
use super::server_bridge::ServerBridge;
use super::server_multiplexer::ServerMultiplexer;
use super::shutdown_signal::ShutdownSignal;
use super::socket_signal::{
@ -155,7 +155,7 @@ pub async fn serve(
launcher_paths: &LauncherPaths,
code_server_args: &CodeServerArgs,
platform: Platform,
shutdown_rx: mpsc::UnboundedReceiver<ShutdownSignal>,
mut shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<ServerTermination, AnyError> {
let mut port = tunnel.add_port_direct(CONTROL_PORT).await?;
print_listening(log, &tunnel.name);
@ -164,12 +164,10 @@ pub async fn serve(
let (tx, mut rx) = mpsc::channel::<ServerSignal>(4);
let (exit_barrier, signal_exit) = new_barrier();
pin!(shutdown_rx);
loop {
tokio::select! {
Some(r) = shutdown_rx.recv() => {
info!(log, "Shutting down: {}", r );
Ok(r) = shutdown_rx.wait() => {
info!(log, "Shutting down: {}", r);
drop(signal_exit);
return Ok(ServerTermination {
respawn: false,

View file

@ -41,7 +41,10 @@ pub fn require_consent(
if accept_server_license_terms {
load.consented = Some(true);
} else if !*IS_INTERACTIVE_CLI {
return Err(MissingLegalConsent("Run this command again with --accept-server-license-terms to indicate your agreement.".to_string())
return Err(MissingLegalConsent(
"Run this command again with --accept-server-license-terms to indicate your agreement."
.to_string(),
)
.into());
} else {
match prompt_yn(prompt) {

View file

@ -4,7 +4,10 @@
*--------------------------------------------------------------------------------------------*/
use std::collections::HashMap;
use crate::{constants::{VSCODE_CLI_VERSION, PROTOCOL_VERSION}, options::Quality};
use crate::{
constants::{PROTOCOL_VERSION, VSCODE_CLI_VERSION},
options::Quality,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Debug)]
@ -154,3 +157,22 @@ impl Default for VersionParams {
}
}
}
pub mod singleton {
use crate::log;
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
pub struct LogMessage<'a> {
pub level: log::Level,
pub prefix: &'a str,
pub message: &'a str,
}
#[derive(Deserialize)]
pub struct LogMessageOwned {
pub level: log::Level,
pub prefix: String,
pub message: String,
}
}

View file

@ -2,36 +2,19 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::path::Path;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{unix::OwnedWriteHalf, UnixStream},
};
use crate::util::errors::{wrap, AnyError};
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
use crate::{
async_pipe::{get_socket_rw_stream, socket_stream_split, AsyncPipeWriteHalf},
util::errors::AnyError,
};
use std::path::Path;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
pub struct ServerBridge {
write: OwnedWriteHalf,
write: AsyncPipeWriteHalf,
decoder: ClientMessageDecoder,
}
pub async fn get_socket_rw_stream(path: &Path) -> Result<UnixStream, AnyError> {
let s = UnixStream::connect(path).await.map_err(|e| {
wrap(
e,
format!(
"error connecting to vscode server socket in {}",
path.display()
),
)
})?;
Ok(s)
}
const BUFFER_SIZE: usize = 65536;
impl ServerBridge {
@ -41,7 +24,7 @@ impl ServerBridge {
decoder: ClientMessageDecoder,
) -> Result<Self, AnyError> {
let stream = get_socket_rw_stream(path).await?;
let (mut read, write) = stream.into_split();
let (mut read, write) = socket_stream_split(stream);
tokio::spawn(async move {
let mut read_buf = vec![0; BUFFER_SIZE];

View file

@ -1,132 +0,0 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{path::Path, time::Duration};
use tokio::{
io::{self, Interest},
net::windows::named_pipe::{ClientOptions, NamedPipeClient},
sync::mpsc,
time::sleep,
};
use crate::util::errors::{wrap, AnyError};
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
pub struct ServerBridge {
write_tx: mpsc::Sender<Vec<u8>>,
decoder: ClientMessageDecoder,
}
const BUFFER_SIZE: usize = 65536;
pub async fn get_socket_rw_stream(path: &Path) -> Result<NamedPipeClient, AnyError> {
// Tokio says we can need to try in a loop. Do so.
// https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
let client = loop {
match ClientOptions::new().open(path) {
Ok(client) => break client,
// ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499-
Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await,
Err(e) => {
return Err(AnyError::WrappedError(wrap(
e,
format!(
"error connecting to vscode server socket in {}",
path.display()
),
)))
}
}
};
Ok(client)
}
impl ServerBridge {
pub async fn new(
path: &Path,
mut target: ServerMessageSink,
decoder: ClientMessageDecoder,
) -> Result<Self, AnyError> {
let client = get_socket_rw_stream(path).await?;
let (write_tx, mut write_rx) = mpsc::channel(4);
tokio::spawn(async move {
let mut read_buf = vec![0; BUFFER_SIZE];
let mut pending_recv: Option<Vec<u8>> = None;
// See https://docs.rs/tokio/1.17.0/tokio/net/windows/named_pipe/struct.NamedPipeClient.html#method.ready
// With additional complications. If there's nothing queued to write, we wait for the
// pipe to be readable, or for something to come in. If there is something to
// write, wait until the pipe is either readable or writable.
loop {
let ready_result = if pending_recv.is_none() {
tokio::select! {
msg = write_rx.recv() => match msg {
Some(msg) => {
pending_recv = Some(msg);
client.ready(Interest::READABLE | Interest::WRITABLE).await
},
None => return
},
r = client.ready(Interest::READABLE) => r,
}
} else {
client.ready(Interest::READABLE | Interest::WRITABLE).await
};
let ready = match ready_result {
Ok(r) => r,
Err(_) => return,
};
if ready.is_readable() {
match client.try_read(&mut read_buf) {
Ok(0) => return, // EOF
Ok(s) => {
let send = target.server_message(&read_buf[..s]).await;
if send.is_err() {
return;
}
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue;
}
Err(_) => return,
}
}
if let Some(msg) = &pending_recv {
if ready.is_writable() {
match client.try_write(msg) {
Ok(n) if n == msg.len() => pending_recv = None,
Ok(n) => pending_recv = Some(msg[n..].to_vec()),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
continue;
}
Err(_) => return,
}
}
}
}
});
Ok(ServerBridge { write_tx, decoder })
}
pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
let dec = self.decoder.decode(&b)?;
if !dec.is_empty() {
self.write_tx.send(dec.to_vec()).await.ok();
}
Ok(())
}
pub async fn close(self) -> std::io::Result<()> {
drop(self.write_tx);
Ok(())
}
}

View file

@ -6,15 +6,12 @@
use std::path::{Path, PathBuf};
use async_trait::async_trait;
use tokio::sync::mpsc;
use crate::log;
use crate::state::LauncherPaths;
use crate::util::errors::{wrap, AnyError};
use crate::util::io::{tailf, TailEvent};
use super::shutdown_signal::ShutdownSignal;
pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log";
#[async_trait]
@ -23,7 +20,6 @@ pub trait ServiceContainer: Send {
&mut self,
log: log::Logger,
launcher_paths: LauncherPaths,
shutdown_rx: mpsc::UnboundedReceiver<ShutdownSignal>,
) -> Result<(), AnyError>;
}

View file

@ -10,7 +10,6 @@ use std::{
process::Command,
};
use super::shutdown_signal::ShutdownSignal;
use async_trait::async_trait;
use zbus::{dbus_proxy, zvariant, Connection};
@ -119,8 +118,7 @@ impl ServiceManager for SystemdService {
launcher_paths: crate::state::LauncherPaths,
mut handle: impl 'static + super::ServiceContainer,
) -> Result<(), crate::util::errors::AnyError> {
let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]);
handle.run_service(self.log, launcher_paths, rx).await
handle.run_service(self.log, launcher_paths).await
}
async fn show_logs(&self) -> Result<(), AnyError> {

View file

@ -9,7 +9,6 @@ use std::{
path::{Path, PathBuf},
};
use super::shutdown_signal::ShutdownSignal;
use async_trait::async_trait;
use crate::{
@ -73,8 +72,7 @@ impl ServiceManager for LaunchdService {
launcher_paths: crate::state::LauncherPaths,
mut handle: impl 'static + super::ServiceContainer,
) -> Result<(), crate::util::errors::AnyError> {
let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]);
handle.run_service(self.log, launcher_paths, rx).await
handle.run_service(self.log, launcher_paths).await
}
async fn unregister(&self) -> Result<(), crate::util::errors::AnyError> {

View file

@ -17,7 +17,6 @@ use crate::{
constants::TUNNEL_ACTIVITY_NAME,
log,
state::LauncherPaths,
tunnels::shutdown_signal::ShutdownSignal,
util::errors::{wrap, wrapdbg, AnyError},
};
@ -90,8 +89,7 @@ impl CliServiceManager for WindowsService {
launcher_paths: LauncherPaths,
mut handle: impl 'static + ServiceContainer,
) -> Result<(), AnyError> {
let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]);
handle.run_service(self.log, launcher_paths, rx).await
handle.run_service(self.log, launcher_paths).await
}
async fn unregister(&self) -> Result<(), AnyError> {

View file

@ -3,16 +3,21 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{fmt, time::Duration};
use std::fmt;
use sysinfo::Pid;
use sysinfo::{Pid, SystemExt};
use tokio::{sync::mpsc, time::sleep};
use crate::util::{
machine::wait_until_process_exits,
sync::{new_barrier, Barrier},
};
/// Describes the signal to manully stop the server
#[derive(Copy, Clone)]
pub enum ShutdownSignal {
CtrlC,
ParentProcessKilled(Pid),
ServiceStopped,
RpcShutdownRequested,
}
impl fmt::Display for ShutdownSignal {
@ -23,41 +28,57 @@ impl fmt::Display for ShutdownSignal {
write!(f, "Parent process {} no longer exists", p)
}
ShutdownSignal::ServiceStopped => write!(f, "Service stopped"),
ShutdownSignal::RpcShutdownRequested => write!(f, "RPC client requested shutdown"),
}
}
}
impl ShutdownSignal {
pub enum ShutdownRequest {
CtrlC,
ParentProcessKilled(Pid),
RpcShutdownRequested(Barrier<()>),
Derived(Barrier<ShutdownSignal>),
}
impl ShutdownRequest {
/// Creates a receiver channel sent to once any of the signals are received.
/// Note: does not handle ServiceStopped
pub fn create_rx(signals: &[ShutdownSignal]) -> mpsc::UnboundedReceiver<ShutdownSignal> {
let (tx, rx) = mpsc::unbounded_channel();
for signal in signals {
let tx = tx.clone();
pub fn create_rx(
signals: impl IntoIterator<Item = ShutdownRequest>,
) -> Barrier<ShutdownSignal> {
let (barrier, opener) = new_barrier();
for signal in signals.into_iter() {
let opener = opener.clone();
match signal {
ShutdownSignal::CtrlC => {
ShutdownRequest::CtrlC => {
let ctrl_c = tokio::signal::ctrl_c();
tokio::spawn(async move {
ctrl_c.await.ok();
tx.send(ShutdownSignal::CtrlC).ok();
opener.open(ShutdownSignal::CtrlC)
});
}
ShutdownSignal::ParentProcessKilled(pid) => {
let pid = *pid;
let tx = tx.clone();
ShutdownRequest::ParentProcessKilled(pid) => {
tokio::spawn(async move {
let mut s = sysinfo::System::new();
while s.refresh_process(pid) {
sleep(Duration::from_millis(2000)).await;
}
tx.send(ShutdownSignal::ParentProcessKilled(pid)).ok();
wait_until_process_exits(pid, 2000).await;
opener.open(ShutdownSignal::ParentProcessKilled(pid))
});
}
ShutdownSignal::ServiceStopped => {
unreachable!("Cannot use ServiceStopped in ShutdownSignal::create_rx");
ShutdownRequest::RpcShutdownRequested(mut rx) => {
tokio::spawn(async move {
let _ = rx.wait().await;
opener.open(ShutdownSignal::RpcShutdownRequested)
});
}
ShutdownRequest::Derived(mut rx) => {
tokio::spawn(async move {
if let Ok(s) = rx.wait().await {
opener.open(s);
}
});
}
}
}
rx
barrier
}
}

View file

@ -0,0 +1,45 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::{
async_pipe::{socket_stream_split, AsyncPipe},
json_rpc::{new_json_rpc, start_json_rpc},
log,
util::sync::Barrier,
};
use super::{protocol, shutdown_signal::ShutdownSignal};
pub struct SingletonClientArgs {
pub log: log::Logger,
pub stream: AsyncPipe,
pub shutdown: Barrier<ShutdownSignal>,
}
struct SingletonServerContext {
log: log::Logger,
}
pub async fn start_singleton_client(args: SingletonClientArgs) {
let rpc = new_json_rpc();
debug!(
args.log,
"An existing tunnel is running on this machine, connecting to it..."
);
let mut rpc = rpc.methods(SingletonServerContext {
log: args.log.clone(),
});
rpc.register_sync("log", |log: protocol::singleton::LogMessageOwned, c| {
c.log
.emit(log.level, &format!("{}: {}", log.prefix, log.message));
Ok(())
});
let (read, write) = socket_stream_split(args.stream);
let _ = start_json_rpc(rpc.build(args.log), read, write, (), args.shutdown.clone()).await;
}

View file

@ -0,0 +1,185 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::sync::{Arc, Mutex};
use super::{
code_server::CodeServerArgs,
control_server::ServerTermination,
dev_tunnels::ActiveTunnel,
protocol,
shutdown_signal::{ShutdownRequest, ShutdownSignal},
};
use crate::{
async_pipe::socket_stream_split,
json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer},
log,
rpc::{RpcCaller, RpcDispatcher},
singleton::SingletonServer,
state::LauncherPaths,
update_service::Platform,
util::{
errors::{AnyError, CodeError},
ring_buffer::RingBuffer,
sync::{new_barrier, Barrier, BarrierOpener, ConcatReceivable},
},
};
use tokio::{
pin,
sync::{broadcast, mpsc},
};
pub struct SingletonServerArgs {
pub log: log::Logger,
pub tunnel: ActiveTunnel,
pub paths: LauncherPaths,
pub code_server_args: CodeServerArgs,
pub platform: Platform,
pub server: SingletonServer,
pub shutdown: Barrier<ShutdownSignal>,
pub log_broadcast: BroadcastLogSink,
}
#[derive(Clone)]
struct SingletonServerContext {
shutdown: BarrierOpener<()>,
}
pub async fn start_singleton_server(
mut args: SingletonServerArgs,
) -> Result<ServerTermination, AnyError> {
let (shutdown_rx, shutdown_tx) = new_barrier();
let shutdown_rx = ShutdownRequest::create_rx([
ShutdownRequest::RpcShutdownRequested(shutdown_rx),
ShutdownRequest::Derived(args.shutdown),
]);
let rpc = new_json_rpc();
let mut rpc = rpc.methods(SingletonServerContext {
shutdown: shutdown_tx,
});
rpc.register_sync("shutdown", |_: protocol::EmptyObject, ctx| {
ctx.shutdown.open(());
Ok(())
});
let (r1, r2) = tokio::join!(
serve_singleton_rpc(
args.log_broadcast,
&mut args.server,
rpc.build(args.log.clone()),
shutdown_rx.clone(),
),
super::serve(
&args.log,
args.tunnel,
&args.paths,
&args.code_server_args,
args.platform,
shutdown_rx,
),
);
r1?;
r2
}
async fn serve_singleton_rpc<C: Clone + Send + Sync + 'static>(
log_broadcast: BroadcastLogSink,
server: &mut SingletonServer,
dispatcher: RpcDispatcher<JsonRpcSerializer, C>,
shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<(), CodeError> {
let mut own_shutdown = shutdown_rx.clone();
let shutdown_fut = own_shutdown.wait();
pin!(shutdown_fut);
loop {
let cnx = tokio::select! {
c = server.accept() => c?,
_ = &mut shutdown_fut => return Ok(()),
};
let (read, write) = socket_stream_split(cnx);
let dispatcher = dispatcher.clone();
let msg_rx = log_broadcast.replay_and_subscribe();
let shutdown_rx = shutdown_rx.clone();
tokio::spawn(async move {
let _ = start_json_rpc(dispatcher.clone(), read, write, msg_rx, shutdown_rx).await;
});
}
}
/// Log sink that can broadcast and replay log events. Used for transmitting
/// logs from the singleton to all clients. This should be created and injected
/// into other services, like the tunnel, before `start_singleton_server`
/// is called.
#[derive(Clone)]
pub struct BroadcastLogSink {
recent: Arc<Mutex<RingBuffer<Vec<u8>>>>,
tx: broadcast::Sender<Vec<u8>>,
}
impl Default for BroadcastLogSink {
fn default() -> Self {
Self::new()
}
}
impl BroadcastLogSink {
pub fn new() -> Self {
let (tx, _) = broadcast::channel(64);
Self {
tx,
recent: Arc::new(Mutex::new(RingBuffer::new(50))),
}
}
fn replay_and_subscribe(
&self,
) -> ConcatReceivable<Vec<u8>, mpsc::UnboundedReceiver<Vec<u8>>, broadcast::Receiver<Vec<u8>>> {
let (log_replay_tx, log_replay_rx) = mpsc::unbounded_channel();
for log in self.recent.lock().unwrap().iter() {
let _ = log_replay_tx.send(log.clone());
}
let _ = log_replay_tx.send(RpcCaller::serialize_notify(
&JsonRpcSerializer {},
"log",
protocol::singleton::LogMessage {
level: log::Level::Info,
prefix: "",
message: "Connected to an existing tunnel process running on this machined.",
},
));
ConcatReceivable::new(log_replay_rx, self.tx.subscribe())
}
}
impl log::LogSink for BroadcastLogSink {
fn write_log(&self, level: log::Level, prefix: &str, message: &str) {
let s = JsonRpcSerializer {};
let serialized = RpcCaller::serialize_notify(
&s,
"log",
protocol::singleton::LogMessage {
level,
prefix,
message,
},
);
let _ = self.tx.send(serialized.clone());
self.recent.lock().unwrap().push(serialized);
}
fn write_result(&self, message: &str) {
self.write_log(log::Level::Info, "", message);
}
}

View file

@ -16,6 +16,7 @@ use crate::{
wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError,
},
http::ReqwestSimpleHttp,
sync::Barrier,
},
};
@ -69,7 +70,7 @@ pub async fn serve_wsl(
code_server_args: CodeServerArgs,
platform: Platform,
http: reqwest::Client,
shutdown_rx: mpsc::UnboundedReceiver<ShutdownSignal>,
shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<i32, AnyError> {
let (caller_tx, caller_rx) = mpsc::unbounded_channel();
let mut rpc = new_msgpack_rpc();

View file

@ -12,8 +12,10 @@ pub mod input;
pub mod io;
pub mod machine;
pub mod prereqs;
pub mod ring_buffer;
pub mod sync;
pub use is_integrated::*;
pub mod file_lock;
#[cfg(target_os = "linux")]
pub mod tar;

View file

@ -2,11 +2,11 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::fmt::Display;
use crate::constants::{
APPLICATION_NAME, CONTROL_PORT, DOCUMENTATION_URL, QUALITYLESS_PRODUCT_NAME,
};
use std::fmt::Display;
use thiserror::Error;
// Wraps another error with additional info.
#[derive(Debug, Clone)]
@ -475,6 +475,22 @@ macro_rules! makeAnyError {
};
}
/// Internal errors in the VS Code CLI.
/// Note: other error should be migrated to this type gradually
#[derive(Error, Debug)]
pub enum CodeError {
#[error("could not connect to socket/pipe")]
AsyncPipeFailed(std::io::Error),
#[error("could not listen on socket/pipe")]
AsyncPipeListenerFailed(std::io::Error),
#[error("could not create singleton lock file")]
SingletonLockfileOpenFailed(std::io::Error),
#[error("could not read singleton lock file")]
SingletonLockfileReadFailed(rmp_serde::decode::Error),
#[error("the process holding the singleton lock file exited")]
SingletonLockedProcessExited(u32),
}
makeAnyError!(
MissingLegalConsent,
MismatchConnectionToken,
@ -505,7 +521,8 @@ makeAnyError!(
MissingHomeDirectory,
CommandFailed,
OAuthError,
InvalidRpcDataError
InvalidRpcDataError,
CodeError
);
impl From<reqwest::Error> for AnyError {

125
cli/src/util/file_lock.rs Normal file
View file

@ -0,0 +1,125 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::util::errors::CodeError;
use std::{fs::File, io};
pub struct FileLock {
file: File,
#[cfg(windows)]
overlapped: winapi::um::minwinbase::OVERLAPPED,
}
#[cfg(windows)] // overlapped is thread-safe, mark it so with this
unsafe impl Send for FileLock {}
pub enum Lock {
Acquired(FileLock),
AlreadyLocked(File),
}
/// Number of locked bytes in the file. On Windows, locking prevents reads,
/// but consumers of the lock may still want to read what the locking file
/// as written. Thus, only PREFIX_LOCKED_BYTES are locked, and any globally-
/// readable content should be written after the prefix.
#[cfg(windows)]
pub const PREFIX_LOCKED_BYTES: usize = 1;
#[cfg(unix)]
pub const PREFIX_LOCKED_BYTES: usize = 0;
impl FileLock {
#[cfg(windows)]
pub fn acquire(file: File) -> Result<Lock, CodeError> {
use std::os::windows::prelude::AsRawHandle;
use winapi::{
shared::winerror::{ERROR_IO_PENDING, ERROR_LOCK_VIOLATION},
um::{
fileapi::LockFileEx,
minwinbase::{LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY},
},
};
let handle = file.as_raw_handle();
let (overlapped, ok) = unsafe {
let mut overlapped = std::mem::zeroed();
let ok = LockFileEx(
handle,
LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY,
0,
PREFIX_LOCKED_BYTES as u32,
0,
&mut overlapped,
);
(overlapped, ok)
};
if ok != 0 {
return Ok(Lock::Acquired(Self { file, overlapped }));
}
let err = io::Error::last_os_error();
let raw = err.raw_os_error();
// docs report it should return ERROR_IO_PENDING, but in my testing it actually
// returns ERROR_LOCK_VIOLATION. Or maybe winapi is wrong?
if raw == Some(ERROR_IO_PENDING as i32) || raw == Some(ERROR_LOCK_VIOLATION as i32) {
return Ok(Lock::AlreadyLocked(file));
}
Err(CodeError::SingletonLockfileOpenFailed(err))
}
#[cfg(unix)]
pub fn acquire(file: File) -> Result<Lock, CodeError> {
use std::os::unix::io::AsRawFd;
let fd = file.as_raw_fd();
let res = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
if res == 0 {
return Ok(Lock::Acquired(Self { file }));
}
let err = io::Error::last_os_error();
if err.kind() == io::ErrorKind::WouldBlock {
return Ok(Lock::AlreadyLocked(file));
}
Err(CodeError::SingletonLockfileOpenFailed(err))
}
pub fn file(&self) -> &File {
&self.file
}
pub fn file_mut(&mut self) -> &mut File {
&mut self.file
}
}
impl Drop for FileLock {
#[cfg(windows)]
fn drop(&mut self) {
use std::os::windows::prelude::AsRawHandle;
use winapi::um::fileapi::UnlockFileEx;
unsafe {
UnlockFileEx(
self.file.as_raw_handle(),
0,
u32::MAX,
u32::MAX,
&mut self.overlapped,
)
};
}
#[cfg(unix)]
fn drop(&mut self) {
use std::os::unix::io::AsRawFd;
unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) };
}
}

View file

@ -15,6 +15,8 @@ use tokio::{
time::sleep,
};
use super::ring_buffer::RingBuffer;
pub trait ReportCopyProgress {
fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64);
}
@ -132,8 +134,7 @@ pub fn tailf(file: File, n: usize) -> mpsc::UnboundedReceiver<TailEvent> {
// Read the initial "n" lines back from the request. initial_lines
// is a small ring buffer.
let mut initial_lines = Vec::with_capacity(n);
let mut initial_lines_i = 0;
let mut initial_lines = RingBuffer::new(n);
loop {
let mut line = String::new();
let bytes_read = match reader.read_line(&mut line) {
@ -151,26 +152,11 @@ pub fn tailf(file: File, n: usize) -> mpsc::UnboundedReceiver<TailEvent> {
}
pos += bytes_read as u64;
if initial_lines.len() < initial_lines.capacity() {
initial_lines.push(line)
} else {
initial_lines[initial_lines_i] = line;
}
initial_lines_i = (initial_lines_i + 1) % n;
initial_lines.push(line);
}
// remove tail lines...
if initial_lines_i < initial_lines.len() {
for line in initial_lines.drain((initial_lines_i)..) {
tx.send(TailEvent::Line(line)).ok();
}
}
// then the remaining lines
if !initial_lines.is_empty() {
for line in initial_lines.drain(0..) {
tx.send(TailEvent::Line(line)).ok();
}
for line in initial_lines.into_iter() {
tx.send(TailEvent::Line(line)).ok();
}
// now spawn the poll process to keep reading new lines

View file

@ -3,7 +3,7 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::path::Path;
use std::{path::Path, time::Duration};
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};
pub fn process_at_path_exists(pid: u32, name: &Path) -> bool {
@ -29,6 +29,14 @@ pub fn process_exists(pid: u32) -> bool {
sys.refresh_process(Pid::from_u32(pid))
}
pub async fn wait_until_process_exits(pid: Pid, poll_ms: u64) {
let mut s = System::new();
let duration = Duration::from_millis(poll_ms);
while s.refresh_process(pid) {
tokio::time::sleep(duration).await;
}
}
pub fn find_running_process(name: &Path) -> Option<u32> {
let mut sys = System::new();
sys.refresh_processes();

142
cli/src/util/ring_buffer.rs Normal file
View file

@ -0,0 +1,142 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
pub struct RingBuffer<T> {
data: Vec<T>,
i: usize,
}
impl<T> RingBuffer<T> {
pub fn new(capacity: usize) -> Self {
Self {
data: Vec::with_capacity(capacity),
i: 0,
}
}
pub fn capacity(&self) -> usize {
self.data.capacity()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_full(&self) -> bool {
self.data.len() == self.data.capacity()
}
pub fn is_empty(&self) -> bool {
self.data.len() == 0
}
pub fn push(&mut self, value: T) {
if self.data.len() == self.data.capacity() {
self.data[self.i] = value;
} else {
self.data.push(value);
}
self.i = (self.i + 1) % self.data.capacity();
}
pub fn iter(&self) -> RingBufferIter<'_, T> {
RingBufferIter {
index: 0,
buffer: self,
}
}
}
impl<T: Default> IntoIterator for RingBuffer<T> {
type Item = T;
type IntoIter = OwnedRingBufferIter<T>;
fn into_iter(self) -> OwnedRingBufferIter<T>
where
T: Default,
{
OwnedRingBufferIter {
index: 0,
buffer: self,
}
}
}
pub struct OwnedRingBufferIter<T: Default> {
buffer: RingBuffer<T>,
index: usize,
}
impl<T: Default> Iterator for OwnedRingBufferIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.index == self.buffer.len() {
return None;
}
let ii = (self.index + self.buffer.i) % self.buffer.len();
let item = std::mem::take(&mut self.buffer.data[ii]);
self.index += 1;
Some(item)
}
}
pub struct RingBufferIter<'a, T> {
buffer: &'a RingBuffer<T>,
index: usize,
}
impl<'a, T> Iterator for RingBufferIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
if self.index == self.buffer.len() {
return None;
}
let ii = (self.index + self.buffer.i) % self.buffer.len();
let item = &self.buffer.data[ii];
self.index += 1;
Some(item)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_inserts() {
let mut rb = RingBuffer::new(3);
assert_eq!(rb.capacity(), 3);
assert!(!rb.is_full());
assert_eq!(rb.len(), 0);
assert_eq!(rb.iter().copied().count(), 0);
rb.push(1);
assert!(!rb.is_full());
assert_eq!(rb.len(), 1);
assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![1]);
rb.push(2);
assert!(!rb.is_full());
assert_eq!(rb.len(), 2);
assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![1, 2]);
rb.push(3);
assert!(rb.is_full());
assert_eq!(rb.len(), 3);
assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![1, 2, 3]);
rb.push(4);
assert!(rb.is_full());
assert_eq!(rb.len(), 3);
assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![2, 3, 4]);
assert_eq!(rb.into_iter().collect::<Vec<i32>>(), vec![2, 3, 4]);
}
}

View file

@ -2,38 +2,53 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use tokio::sync::watch::{
self,
error::{RecvError, SendError},
use async_trait::async_trait;
use std::{marker::PhantomData, sync::Arc};
use tokio::sync::{
broadcast, mpsc,
watch::{self, error::RecvError},
};
#[derive(Clone)]
pub struct Barrier<T>(watch::Receiver<Option<T>>)
where
T: Copy;
T: Clone;
impl<T> Barrier<T>
where
T: Copy,
T: Clone,
{
/// Waits for the barrier to be closed, returning a value if one was sent.
pub async fn wait(&mut self) -> Result<T, RecvError> {
loop {
self.0.changed().await?;
if let Some(v) = *(self.0.borrow()) {
if let Some(v) = self.0.borrow().clone() {
return Ok(v);
}
}
}
/// Gets whether the barrier is currently open
pub fn is_open(&self) -> bool {
self.0.borrow().is_some()
}
}
pub struct BarrierOpener<T>(watch::Sender<Option<T>>);
#[derive(Clone)]
pub struct BarrierOpener<T: Clone>(Arc<watch::Sender<Option<T>>>);
impl<T> BarrierOpener<T> {
/// Closes the barrier.
pub fn open(self, value: T) -> Result<(), SendError<Option<T>>> {
self.0.send(Some(value))
impl<T: Clone> BarrierOpener<T> {
/// Opens the barrier.
pub fn open(&self, value: T) {
self.0.send_if_modified(|v| {
if v.is_none() {
*v = Some(value);
true
} else {
false
}
});
}
}
@ -44,7 +59,119 @@ where
T: Copy,
{
let (closed_tx, closed_rx) = watch::channel(None);
(Barrier(closed_rx), BarrierOpener(closed_tx))
(Barrier(closed_rx), BarrierOpener(Arc::new(closed_tx)))
}
/// Type that can receive messages in an async way.
#[async_trait]
pub trait Receivable<T> {
async fn recv_msg(&mut self) -> Option<T>;
}
// todo: ideally we would use an Arc in the broadcast::Receiver to avoid having
// to clone bytes everywhere, requires updating rpc consumers as well.
#[async_trait]
impl<T: Clone + Send> Receivable<T> for broadcast::Receiver<T> {
async fn recv_msg(&mut self) -> Option<T> {
loop {
match self.recv().await {
Ok(v) => return Some(v),
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => return None,
}
}
}
}
#[async_trait]
impl<T: Send> Receivable<T> for mpsc::UnboundedReceiver<T> {
async fn recv_msg(&mut self) -> Option<T> {
self.recv().await
}
}
#[async_trait]
impl<T: Send> Receivable<T> for () {
async fn recv_msg(&mut self) -> Option<T> {
futures::future::pending().await
}
}
pub struct ConcatReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
left: Option<A>,
right: B,
_marker: PhantomData<T>,
}
impl<T: Send, A: Receivable<T>, B: Receivable<T>> ConcatReceivable<T, A, B> {
pub fn new(left: A, right: B) -> Self {
Self {
left: Some(left),
right,
_marker: PhantomData,
}
}
}
#[async_trait]
impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
for ConcatReceivable<T, A, B>
{
async fn recv_msg(&mut self) -> Option<T> {
if let Some(left) = &mut self.left {
match left.recv_msg().await {
Some(v) => return Some(v),
None => {
self.left = None;
}
}
}
return self.right.recv_msg().await;
}
}
pub struct MergedReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
left: Option<A>,
right: Option<B>,
_marker: PhantomData<T>,
}
impl<T: Send, A: Receivable<T>, B: Receivable<T>> MergedReceivable<T, A, B> {
pub fn new(left: A, right: B) -> Self {
Self {
left: Some(left),
right: Some(right),
_marker: PhantomData,
}
}
}
#[async_trait]
impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
for MergedReceivable<T, A, B>
{
async fn recv_msg(&mut self) -> Option<T> {
loop {
match (&mut self.left, &mut self.right) {
(Some(left), Some(right)) => {
tokio::select! {
left = left.recv_msg() => match left {
Some(v) => return Some(v),
None => { self.left = None; continue; },
},
right = right.recv_msg() => match right {
Some(v) => return Some(v),
None => { self.right = None; continue; },
},
}
}
(Some(a), None) => break a.recv_msg().await,
(None, Some(b)) => break b.recv_msg().await,
(None, None) => break None,
}
}
}
}
#[cfg(test)]
@ -60,7 +187,7 @@ mod tests {
tx.send(barrier.wait().await.unwrap()).unwrap();
});
opener.open(42).unwrap();
opener.open(42);
assert!(rx.await.unwrap() == 42);
}
@ -71,7 +198,7 @@ mod tests {
let (tx1, rx1) = tokio::sync::oneshot::channel::<u32>();
let (tx2, rx2) = tokio::sync::oneshot::channel::<u32>();
opener.open(42).unwrap();
opener.open(42);
let mut b1 = barrier.clone();
tokio::spawn(async move {
tx1.send(b1.wait().await.unwrap()).unwrap();