initial import of code-tunnel/cli

This commit is contained in:
Connor Peet 2022-09-19 13:29:04 -07:00
parent 831be46050
commit cc7a21cfdf
No known key found for this signature in database
GPG key ID: CF8FD2EA0DBC61BD
48 changed files with 12030 additions and 0 deletions

20
src/cli/CONTRIBUTING.md Normal file
View file

@ -0,0 +1,20 @@
# Setup
0. Clone, and then run `git submodule update --init --recursive`
1. Get the extensions: [rust-analyzer](https://marketplace.visualstudio.com/items?itemName=matklad.rust-analyzer) and [CodeLLDB](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb)
2. Ensure your workspace is set to the `launcher` folder being the root.
## Building the CLI on Windows
For the moment, we require OpenSSL on Windows, where it is not usually installed by default. To install it:
1. Install (clone) vcpkg [using their instructions](https://github.com/Microsoft/vcpkg#quick-start-windows)
1. Add the location of the `vcpkg` directory to your system or user PATH.
1. Run`vcpkg install openssl:x64-windows-static-md` (after restarting your terminal for PATH changes to apply)
1. You should be able to then `cargo build` successfully
OpenSSL is needed for the key exchange we do when forwarding Basis tunnels. When all interested Basis clients support ED25519, we would be able to solely use libsodium. At the time of writing however, there is [no active development](https://chromestatus.com/feature/4913922408710144) on this in Chromium.
# Debug
1. You can use the Debug tasks already configured to run the launcher.

3065
src/cli/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

65
src/cli/Cargo.toml Normal file
View file

@ -0,0 +1,65 @@
[package]
name = "code-cli"
version = "0.1.0"
edition = "2021"
default-run = "code"
[lib]
name = "cli"
path = "src/lib.rs"
[[bin]]
name = "code-tunnel"
[[bin]]
name = "code"
[dependencies]
futures = "0.3"
clap = { version = "3.0", features = ["derive", "env"] }
open = { version = "2.1.0" }
reqwest = { version = "0.11.9", default-features = false, features = ["json", "stream", "native-tls-vendored"] }
tokio = { version = "1.20", features = ["full"] }
tokio-util = { version = "0.7", features = ["compat"] }
flate2 = { version = "1.0.22" }
zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] }
regex = { version = "1.5.5" }
lazy_static = { version = "1.4.0" }
sysinfo = { version = "0.23.5" }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
rmp-serde = "1.0"
uuid = { version = "0.8.2", features = ["serde", "v4"] }
dirs = "4.0.0"
rand = "0.8.5"
atty = "0.2.14"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-application-insights = { version = "0.22.0", features = ["reqwest-client-vendored-tls"] }
serde_bytes = "0.11.5"
chrono = { version = "0.4", features = ["serde", "rustc-serialize"] }
gethostname = "0.2.3"
libc = "0.2"
tunnels = { git = "https://github.com/connor4312/dev-tunnels", branch = "host-relay", features = ["connections", "vendored-openssl"] }
keyring = "1.1"
dialoguer = "0.10"
hyper = "0.14"
indicatif = "0.16"
tempfile = "3.3"
clap_lex = "0.2"
url = "2.3"
async-trait = "0.1"
log = "0.4"
[target.'cfg(windows)'.dependencies]
windows-service = "0.5"
[target.'cfg(target_os = "linux")'.dependencies]
tar = { version = "0.4" }
[profile.release]
strip = true
lto = true
codegen-units = 1
[features]
vscode-encrypt = []

8
src/cli/Cross.toml Normal file
View file

@ -0,0 +1,8 @@
[build.env]
passthrough = [
"LAUNCHER_VERSION",
"LAUNCHER_ASSET_NAME",
]
[target.aarch64-unknown-linux-gnu]
image = "microsoft/vscode-server-launcher-xbuild:aarch64"

57
src/cli/build.rs Normal file
View file

@ -0,0 +1,57 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
const FILE_HEADER: &[u8] = b"/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/";
use std::{env, fs, io, path::PathBuf, process};
fn main() {
let files = enumerate_source_files().expect("expected to enumerate files");
ensure_file_headers(&files).expect("expected to ensure file headers");
}
fn ensure_file_headers(files: &[PathBuf]) -> Result<(), io::Error> {
let mut ok = true;
for file in files {
let contents = fs::read(file)?;
if !contents.starts_with(FILE_HEADER) {
eprintln!("File missing copyright header: {}", file.display());
ok = false;
}
}
if !ok {
process::exit(1);
}
Ok(())
}
/// Gets all "rs" files in the source directory
fn enumerate_source_files() -> Result<Vec<PathBuf>, io::Error> {
let mut files = vec![];
let mut queue = vec![];
let current_dir = env::current_dir()?.join("src");
queue.push(current_dir);
while !queue.is_empty() {
for entry in fs::read_dir(queue.pop().unwrap())? {
let entry = entry?;
let ftype = entry.file_type()?;
if ftype.is_dir() {
queue.push(entry.path());
} else if ftype.is_file() && entry.file_name().to_string_lossy().ends_with(".rs") {
files.push(entry.path());
}
}
}
Ok(files)
}

594
src/cli/src/auth.rs Normal file
View file

@ -0,0 +1,594 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::{
constants::get_default_user_agent,
info, log,
state::{LauncherPaths, PersistedState},
trace,
util::{
errors::{wrap, AnyError, RefreshTokenNotAvailableError, StatusError, WrappedError},
input::prompt_options,
},
warning,
};
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use gethostname::gethostname;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{cell::Cell, fmt::Display, path::PathBuf, sync::Arc};
use tokio::time::sleep;
use tunnels::{
contracts::PROD_FIRST_PARTY_APP_ID,
management::{Authorization, AuthorizationProvider, HttpError},
};
#[derive(Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
message: Option<String>,
verification_uri: String,
expires_in: i64,
}
#[derive(Deserialize)]
struct AuthenticationResponse {
access_token: String,
refresh_token: Option<String>,
expires_in: Option<i64>,
}
#[derive(clap::ArgEnum, Serialize, Deserialize, Debug, Clone, Copy)]
pub enum AuthProvider {
Microsoft,
Github,
}
impl Display for AuthProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthProvider::Microsoft => write!(f, "Microsoft Account"),
AuthProvider::Github => write!(f, "Github Account"),
}
}
}
impl AuthProvider {
pub fn client_id(&self) -> &'static str {
match self {
AuthProvider::Microsoft => "aebc6443-996d-45c2-90f0-388ff96faa56",
AuthProvider::Github => "01ab8ac9400c4e429b23",
}
}
pub fn code_uri(&self) -> &'static str {
match self {
AuthProvider::Microsoft => {
"https://login.microsoftonline.com/common/oauth2/v2.0/devicecode"
}
AuthProvider::Github => "https://github.com/login/device/code",
}
}
pub fn grant_uri(&self) -> &'static str {
match self {
AuthProvider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token",
AuthProvider::Github => "https://github.com/login/oauth/access_token",
}
}
pub fn get_default_scopes(&self) -> String {
match self {
AuthProvider::Microsoft => format!(
"{}/.default+offline_access+profile+openid",
PROD_FIRST_PARTY_APP_ID
),
AuthProvider::Github => "read:user+read:org".to_string(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StoredCredential {
#[serde(rename = "p")]
provider: AuthProvider,
#[serde(rename = "a")]
access_token: String,
#[serde(rename = "r")]
refresh_token: Option<String>,
#[serde(rename = "e")]
expires_at: Option<DateTime<Utc>>,
}
impl StoredCredential {
pub async fn is_expired(&self, client: &reqwest::Client) -> bool {
match self.provider {
AuthProvider::Microsoft => self
.expires_at
.map(|e| Utc::now() + chrono::Duration::minutes(5) > e)
.unwrap_or(false),
// Make an auth request to Github. Mark the credential as expired
// only on a verifiable 4xx code. We don't error on any failed
// request since then a drop in connection could "require" a refresh
AuthProvider::Github => client
.get("https://api.github.com/user")
.header("Authorization", format!("token {}", self.access_token))
.header("User-Agent", get_default_user_agent())
.send()
.await
.map(|r| r.status().is_client_error())
.unwrap_or(false),
}
}
fn from_response(auth: AuthenticationResponse, provider: AuthProvider) -> Self {
StoredCredential {
provider,
access_token: auth.access_token,
refresh_token: auth.refresh_token,
expires_at: auth.expires_in.map(|e| Utc::now() + Duration::seconds(e)),
}
}
}
struct StorageWithLastRead {
storage: Box<dyn StorageImplementation>,
last_read: Cell<Result<Option<StoredCredential>, WrappedError>>,
}
#[derive(Clone)]
pub struct Auth {
client: reqwest::Client,
log: log::Logger,
file_storage_path: PathBuf,
storage: Arc<std::sync::Mutex<Option<StorageWithLastRead>>>,
}
trait StorageImplementation: Send + Sync {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError>;
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError>;
fn clear(&mut self) -> Result<(), WrappedError>;
}
// unseal decrypts and deserializes the value
fn seal<T>(value: &T) -> String
where
T: Serialize + ?Sized,
{
let dec = serde_json::to_string(value).expect("expected to serialize");
encrypt(&dec)
}
// unseal decrypts and deserializes the value
fn unseal<T>(value: &str) -> Option<T>
where
T: DeserializeOwned,
{
// small back-compat for old unencrypted values
if let Ok(v) = serde_json::from_str::<T>(value) {
return Some(v);
}
let dec = decrypt(value)?;
serde_json::from_str::<T>(&dec).ok()
}
#[cfg(target_os = "windows")]
const KEYCHAIN_ENTRY_LIMIT: usize = 1024;
#[cfg(not(target_os = "windows"))]
const KEYCHAIN_ENTRY_LIMIT: usize = 128 * 1024;
const CONTINUE_MARKER: &str = "<MORE>";
#[derive(Default)]
struct KeyringStorage {
// keywring storage can be split into multiple entries due to entry length limits
// on Windows https://github.com/microsoft/vscode-cli/issues/358
entries: Vec<keyring::Entry>,
}
macro_rules! get_next_entry {
($self: expr, $i: expr) => {
match $self.entries.get($i) {
Some(e) => e,
None => {
let e = keyring::Entry::new("vscode-cli", &format!("vscode-cli-{}", $i));
$self.entries.push(e);
$self.entries.last().unwrap()
}
}
};
}
impl StorageImplementation for KeyringStorage {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
let mut str = String::new();
for i in 0.. {
let entry = get_next_entry!(self, i);
let next_chunk = match entry.get_password() {
Ok(value) => value,
Err(keyring::Error::NoEntry) => return Ok(None), // missing entries?
Err(e) => return Err(wrap(e, "error reading keyring")),
};
if next_chunk.ends_with(CONTINUE_MARKER) {
str.push_str(&next_chunk[..next_chunk.len() - CONTINUE_MARKER.len()]);
} else {
str.push_str(&next_chunk);
break;
}
}
Ok(unseal(&str))
}
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
let sealed = seal(&value);
let step_size = KEYCHAIN_ENTRY_LIMIT - CONTINUE_MARKER.len();
for i in (0..sealed.len()).step_by(step_size) {
let entry = get_next_entry!(self, i / step_size);
let cutoff = i + step_size;
let stored = if cutoff <= sealed.len() {
let mut part = sealed[i..cutoff].to_string();
part.push_str(CONTINUE_MARKER);
entry.set_password(&part)
} else {
entry.set_password(&sealed[i..])
};
if let Err(e) = stored {
return Err(wrap(e, "error updating keyring"));
}
}
Ok(())
}
fn clear(&mut self) -> Result<(), WrappedError> {
self.read().ok(); // make sure component parts are available
for entry in self.entries.iter() {
entry
.delete_password()
.map_err(|e| wrap(e, "error updating keyring"))?;
}
self.entries.clear();
Ok(())
}
}
struct FileStorage(PersistedState<Option<String>>);
impl StorageImplementation for FileStorage {
fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
Ok(self.0.load().and_then(|s| unseal(&s)))
}
fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
self.0.save(Some(seal(&value)))
}
fn clear(&mut self) -> Result<(), WrappedError> {
self.0.save(None)
}
}
impl Auth {
pub fn new(paths: &LauncherPaths, log: log::Logger) -> Auth {
Auth {
log,
client: reqwest::Client::new(),
file_storage_path: paths.root().join("token.json"),
storage: Arc::new(std::sync::Mutex::new(None)),
}
}
fn with_storage<T, F>(&self, op: F) -> T
where
F: FnOnce(&mut StorageWithLastRead) -> T,
{
let mut opt = self.storage.lock().unwrap();
if let Some(s) = opt.as_mut() {
return op(s);
}
let mut keyring_storage = KeyringStorage::default();
let mut file_storage = FileStorage(PersistedState::new(self.file_storage_path.clone()));
let keyring_storage_result = match std::env::var("LAUNCHER_USE_FILE_KEYCHAIN") {
Ok(_) => Err(wrap("", "user prefers file storage")),
_ => keyring_storage.read(),
};
let mut storage = match keyring_storage_result {
Ok(v) => StorageWithLastRead {
last_read: Cell::new(Ok(v)),
storage: Box::new(keyring_storage),
},
Err(_) => StorageWithLastRead {
last_read: Cell::new(file_storage.read()),
storage: Box::new(file_storage),
},
};
let out = op(&mut storage);
*opt = Some(storage);
out
}
/// Gets a tunnel Authentication for use in the tunnel management API.
pub async fn get_tunnel_authentication(&self) -> Result<Authorization, AnyError> {
let cred = self.get_credential().await?;
let auth = match cred.provider {
AuthProvider::Microsoft => Authorization::Bearer(cred.access_token),
AuthProvider::Github => Authorization::Github(format!(
"client_id={} {}",
cred.provider.client_id(),
cred.access_token
)),
};
Ok(auth)
}
/// Reads the current details from the keyring.
pub fn get_current_credential(&self) -> Result<Option<StoredCredential>, WrappedError> {
self.with_storage(|storage| {
let value = storage.last_read.replace(Ok(None));
storage.last_read.set(value.clone());
value
})
}
/// Clears login info from the keyring.
pub fn clear_credentials(&self) -> Result<(), WrappedError> {
self.with_storage(|storage| {
storage.storage.clear()?;
storage.last_read.set(Ok(None));
Ok(())
})
}
/// Runs the login flow, optionally pre-filling a provider and/or access token.
pub async fn login(
&self,
provider: Option<AuthProvider>,
access_token: Option<String>,
) -> Result<StoredCredential, AnyError> {
let provider = match provider {
Some(p) => p,
None => self.prompt_for_provider().await?,
};
let credentials = match access_token {
Some(t) => StoredCredential {
provider,
access_token: t,
refresh_token: None,
expires_at: None,
},
None => self.do_device_code_flow_with_provider(provider).await?,
};
self.store_credentials(credentials.clone());
Ok(credentials)
}
/// Gets the currently stored credentials, or asks the user to log in.
pub async fn get_credential(&self) -> Result<StoredCredential, AnyError> {
let entry = match self.get_current_credential() {
Ok(Some(old_creds)) => {
trace!(self.log, "Found token in keyring");
match self.get_refreshed_token(&old_creds).await {
Ok(Some(new_creds)) => {
self.store_credentials(new_creds.clone());
new_creds
}
Ok(None) => old_creds,
Err(e) => {
info!(self.log, "error refreshing token: {}", e);
let new_creds = self
.do_device_code_flow_with_provider(old_creds.provider)
.await?;
self.store_credentials(new_creds.clone());
new_creds
}
}
}
Ok(None) => {
trace!(self.log, "No token in keyring, getting a new one");
let creds = self.do_device_code_flow().await?;
self.store_credentials(creds.clone());
creds
}
Err(e) => {
warning!(
self.log,
"Error reading token from keyring, getting a new one: {}",
e
);
let creds = self.do_device_code_flow().await?;
self.store_credentials(creds.clone());
creds
}
};
Ok(entry)
}
/// Stores credentials, logging a warning if it fails.
fn store_credentials(&self, creds: StoredCredential) {
self.with_storage(|storage| {
if let Err(e) = storage.storage.store(creds.clone()) {
warning!(
self.log,
"Failed to update keyring with new credentials: {}",
e
);
}
storage.last_read.set(Ok(Some(creds)));
})
}
/// Refreshes the token in the credentials if necessary. Returns None if
/// the token is up to date, or Some new token otherwise.
async fn get_refreshed_token(
&self,
creds: &StoredCredential,
) -> Result<Option<StoredCredential>, AnyError> {
if !creds.is_expired(&self.client).await {
return Ok(None);
}
let refresh_token = match &creds.refresh_token {
Some(t) => t,
None => return Err(AnyError::from(RefreshTokenNotAvailableError())),
};
self.do_grant(
creds.provider,
format!(
"client_id={}&grant_type=refresh_token&refresh_token={}",
creds.provider.client_id(),
refresh_token
),
)
.await
.map(Some)
}
/// Does a "grant token" request.
async fn do_grant(
&self,
provider: AuthProvider,
body: String,
) -> Result<StoredCredential, AnyError> {
let response = self
.client
.post(provider.grant_uri())
.body(body)
.header("Accept", "application/json")
.send()
.await?;
if !response.status().is_success() {
return Err(StatusError::from_res(response).await?.into());
}
let body = response.json::<AuthenticationResponse>().await?;
Ok(StoredCredential::from_response(body, provider))
}
/// Implements the device code flow, returning the credentials upon success.
async fn do_device_code_flow(&self) -> Result<StoredCredential, AnyError> {
let provider = self.prompt_for_provider().await?;
self.do_device_code_flow_with_provider(provider).await
}
async fn prompt_for_provider(&self) -> Result<AuthProvider, AnyError> {
if std::env::var("LAUNCHER_ALLOW_MS_AUTH").is_err() {
return Ok(AuthProvider::Github);
}
let provider = prompt_options(
"How would you like to log in to VS Code?",
&[AuthProvider::Microsoft, AuthProvider::Github],
)?;
Ok(provider)
}
async fn do_device_code_flow_with_provider(
&self,
provider: AuthProvider,
) -> Result<StoredCredential, AnyError> {
loop {
let init_code = self
.client
.post(provider.code_uri())
.header("Accept", "application/json")
.body(format!(
"client_id={}&scope={}",
provider.client_id(),
provider.get_default_scopes(),
))
.send()
.await?;
if !init_code.status().is_success() {
return Err(StatusError::from_res(init_code).await?.into());
}
let init_code_json = init_code.json::<DeviceCodeResponse>().await?;
let expires_at = Utc::now() + chrono::Duration::seconds(init_code_json.expires_in);
match &init_code_json.message {
Some(m) => self.log.result(m),
None => self.log.result(&format!(
"To grant access to the server, please log into {} and use code {}",
init_code_json.verification_uri, init_code_json.user_code
)),
};
let body = format!(
"client_id={}&grant_type=urn:ietf:params:oauth:grant-type:device_code&device_code={}",
provider.client_id(),
init_code_json.device_code
);
while Utc::now() < expires_at {
sleep(std::time::Duration::from_secs(5)).await;
match self.do_grant(provider, body.clone()).await {
Ok(creds) => return Ok(creds),
Err(e) => {
trace!(self.log, "refresh poll failed, retrying: {}", e);
}
}
}
}
}
}
#[async_trait]
impl AuthorizationProvider for Auth {
async fn get_authorization(&self) -> Result<Authorization, HttpError> {
self.get_tunnel_authentication()
.await
.map_err(|e| HttpError::AuthorizationError(e.to_string()))
}
}
lazy_static::lazy_static! {
static ref HOSTNAME: Vec<u8> = gethostname().to_string_lossy().bytes().collect();
}
#[cfg(feature = "vscode-encrypt")]
fn encrypt(value: &str) -> String {
vscode_encrypt::encrypt(&HOSTNAME, value.as_bytes()).expect("expected to encrypt")
}
#[cfg(feature = "vscode-encrypt")]
fn decrypt(value: &str) -> Option<String> {
let b = vscode_encrypt::decrypt(&HOSTNAME, value).ok()?;
String::from_utf8(b).ok()
}
#[cfg(not(feature = "vscode-encrypt"))]
fn encrypt(value: &str) -> String {
value.to_owned()
}
#[cfg(not(feature = "vscode-encrypt"))]
fn decrypt(value: &str) -> Option<String> {
Some(value.to_owned())
}

View file

@ -0,0 +1,123 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use clap::Parser;
use cli::{
commands::{args, tunnels, CommandContext},
constants, log as own_log,
state::LauncherPaths,
};
use opentelemetry::sdk::trace::TracerProvider as SdkTracerProvider;
use opentelemetry::trace::TracerProvider;
use log::{Level, Metadata, Record};
#[derive(Parser, Debug)]
#[clap(
long_about = None,
name = "Visual Studio Code Tunnels CLI",
version = match constants::LAUNCHER_VERSION { Some(v) => v, None => "dev" },
)]
pub struct TunnelCli {
#[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))]
pub global_options: args::GlobalOptions,
#[clap(flatten, next_help_heading = Some("TUNNEL OPTIONS"))]
pub tunnel_options: args::TunnelArgs,
}
/// Entrypoint for a standalone "code-tunnel" subcommand. This is a temporary
/// artifact until we're ready to do swap to the full "code" CLI, and most
/// code in here is duplicated from `src/bin/code/main.rs`
#[tokio::main]
async fn main() -> Result<(), std::convert::Infallible> {
let parsed = TunnelCli::parse();
let context = CommandContext {
http: reqwest::Client::new(),
paths: LauncherPaths::new(&parsed.global_options.cli_data_dir).unwrap(),
log: own_log::Logger::new(
SdkTracerProvider::builder().build().tracer("codecli"),
if parsed.global_options.verbose {
own_log::Level::Trace
} else {
parsed.global_options.log.unwrap_or(own_log::Level::Info)
},
),
args: args::Cli {
global_options: parsed.global_options,
subcommand: Some(args::Commands::Tunnel(parsed.tunnel_options.clone())),
..Default::default()
},
};
log::set_logger(Box::leak(Box::new(RustyLogger(context.log.clone()))))
.map(|()| log::set_max_level(log::LevelFilter::Debug))
.expect("expected to make logger");
let result = match parsed.tunnel_options.subcommand {
Some(args::TunnelSubcommand::Prune) => tunnels::prune(context).await,
Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context).await,
Some(args::TunnelSubcommand::Rename(rename_args)) => {
tunnels::rename(context, rename_args).await
}
Some(args::TunnelSubcommand::User(user_command)) => {
tunnels::user(context, user_command).await
}
Some(args::TunnelSubcommand::Service(service_args)) => {
tunnels::service(context, service_args).await
}
None => tunnels::serve(context, parsed.tunnel_options.serve_args).await,
};
match result {
Err(e) => print_and_exit(e),
Ok(code) => std::process::exit(code),
}
}
fn print_and_exit<E>(err: E) -> !
where
E: std::fmt::Display,
{
own_log::emit(own_log::Level::Error, "", &format!("{}", err));
std::process::exit(1);
}
/// Logger that uses the common rust "log" crate and directs back to one of
/// our managed loggers.
struct RustyLogger(own_log::Logger);
impl log::Log for RustyLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
metadata.level() <= Level::Debug
}
fn log(&self, record: &Record) {
if !self.enabled(record.metadata()) {
return;
}
// exclude noisy log modules:
let src = match record.module_path() {
Some("russh::cipher") => return,
Some("russh::negotiation") => return,
Some(s) => s,
None => "<unknown>",
};
self.0.emit(
match record.level() {
log::Level::Debug => own_log::Level::Debug,
log::Level::Error => own_log::Level::Error,
log::Level::Info => own_log::Level::Info,
log::Level::Trace => own_log::Level::Trace,
log::Level::Warn => own_log::Level::Warn,
},
&format!("[{}] {}", src, record.args()),
);
}
fn flush(&self) {}
}

View file

@ -0,0 +1,234 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::collections::HashMap;
use cli::commands::args::{
Cli, Commands, DesktopCodeOptions, ExtensionArgs, ExtensionSubcommand, InstallExtensionArgs,
ListExtensionArgs, UninstallExtensionArgs,
};
/// Tries to parse the argv using the legacy CLI interface, looking for its
/// flags and generating a CLI with subcommands if those don't exist.
pub fn try_parse_legacy(
iter: impl IntoIterator<Item = impl Into<std::ffi::OsString>>,
) -> Option<Cli> {
let raw = clap_lex::RawArgs::new(iter);
let mut cursor = raw.cursor();
raw.next(&mut cursor); // Skip the bin
// First make a hashmap of all flags and capture positional arguments.
let mut args: HashMap<String, Vec<String>> = HashMap::new();
let mut last_arg = None;
while let Some(arg) = raw.next(&mut cursor) {
if let Some((long, value)) = arg.to_long() {
if let Ok(long) = long {
last_arg = Some(long.to_string());
match args.get_mut(long) {
Some(prev) => {
if let Some(v) = value {
prev.push(v.to_str_lossy().to_string());
}
}
None => {
if let Some(v) = value {
args.insert(long.to_string(), vec![v.to_str_lossy().to_string()]);
} else {
args.insert(long.to_string(), vec![]);
}
}
}
}
} else if let Ok(value) = arg.to_value() {
if let Some(last_arg) = &last_arg {
args.get_mut(last_arg)
.expect("expected to have last arg")
.push(value.to_string());
}
}
}
let get_first_arg_value =
|key: &str| args.get(key).and_then(|v| v.first()).map(|s| s.to_string());
let desktop_code_options = DesktopCodeOptions {
extensions_dir: get_first_arg_value("extensions-dir"),
user_data_dir: get_first_arg_value("user-data-dir"),
use_version: None,
};
// Now translate them to subcommands.
// --list-extensions -> ext list
// --install-extension=id -> ext install <id>
// --uninstall-extension=id -> ext uninstall <id>
// --status -> status
if args.contains_key("list-extensions") {
Some(Cli {
subcommand: Some(Commands::Extension(ExtensionArgs {
subcommand: ExtensionSubcommand::List(ListExtensionArgs {
category: get_first_arg_value("category"),
show_versions: args.contains_key("show-versions"),
}),
desktop_code_options,
})),
..Default::default()
})
} else if let Some(exts) = args.remove("install-extension") {
Some(Cli {
subcommand: Some(Commands::Extension(ExtensionArgs {
subcommand: ExtensionSubcommand::Install(InstallExtensionArgs {
id_or_path: exts,
pre_release: args.contains_key("pre-release"),
force: args.contains_key("force"),
}),
desktop_code_options,
})),
..Default::default()
})
} else if let Some(exts) = args.remove("uninstall-extension") {
Some(Cli {
subcommand: Some(Commands::Extension(ExtensionArgs {
subcommand: ExtensionSubcommand::Uninstall(UninstallExtensionArgs { id: exts }),
desktop_code_options,
})),
..Default::default()
})
} else if args.contains_key("status") {
Some(Cli {
subcommand: Some(Commands::Status),
..Default::default()
})
} else {
None
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parses_list_extensions() {
let args = vec![
"code",
"--list-extensions",
"--category",
"themes",
"--show-versions",
];
let cli = try_parse_legacy(args.into_iter()).unwrap();
if let Some(Commands::Extension(extension_args)) = cli.subcommand {
if let ExtensionSubcommand::List(list_args) = extension_args.subcommand {
assert_eq!(list_args.category, Some("themes".to_string()));
assert!(list_args.show_versions);
} else {
panic!(
"Expected list subcommand, got {:?}",
extension_args.subcommand
);
}
} else {
panic!("Expected extension subcommand, got {:?}", cli.subcommand);
}
}
#[test]
fn test_parses_install_extension() {
let args = vec![
"code",
"--install-extension",
"connor4312.codesong",
"connor4312.hello-world",
"--pre-release",
"--force",
];
let cli = try_parse_legacy(args.into_iter()).unwrap();
if let Some(Commands::Extension(extension_args)) = cli.subcommand {
if let ExtensionSubcommand::Install(install_args) = extension_args.subcommand {
assert_eq!(
install_args.id_or_path,
vec!["connor4312.codesong", "connor4312.hello-world"]
);
assert!(install_args.pre_release);
assert!(install_args.force);
} else {
panic!(
"Expected install subcommand, got {:?}",
extension_args.subcommand
);
}
} else {
panic!("Expected extension subcommand, got {:?}", cli.subcommand);
}
}
#[test]
fn test_parses_uninstall_extension() {
let args = vec!["code", "--uninstall-extension", "connor4312.codesong"];
let cli = try_parse_legacy(args.into_iter()).unwrap();
if let Some(Commands::Extension(extension_args)) = cli.subcommand {
if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand {
assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]);
} else {
panic!(
"Expected uninstall subcommand, got {:?}",
extension_args.subcommand
);
}
} else {
panic!("Expected extension subcommand, got {:?}", cli.subcommand);
}
}
#[test]
fn test_parses_user_data_dir_and_extensions_dir() {
let args = vec![
"code",
"--uninstall-extension",
"connor4312.codesong",
"--user-data-dir",
"foo",
"--extensions-dir",
"bar",
];
let cli = try_parse_legacy(args.into_iter()).unwrap();
if let Some(Commands::Extension(extension_args)) = cli.subcommand {
assert_eq!(
extension_args.desktop_code_options.user_data_dir,
Some("foo".to_string())
);
assert_eq!(
extension_args.desktop_code_options.extensions_dir,
Some("bar".to_string())
);
if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand {
assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]);
} else {
panic!(
"Expected uninstall subcommand, got {:?}",
extension_args.subcommand
);
}
} else {
panic!("Expected extension subcommand, got {:?}", cli.subcommand);
}
}
#[test]
fn test_status() {
let args = vec!["code", "--status"];
let cli = try_parse_legacy(args.into_iter()).unwrap();
if let Some(Commands::Status) = cli.subcommand {
// no-op
} else {
panic!("Expected extension subcommand, got {:?}", cli.subcommand);
}
}
}

View file

@ -0,0 +1,169 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
mod legacy_args;
use std::process::Command;
use clap::Parser;
use cli::{
commands::{args, tunnels, version, CommandContext},
desktop, log as own_log,
state::LauncherPaths,
update_service::UpdateService,
util::{
errors::{wrap, AnyError},
prereqs::PreReqChecker,
},
};
use legacy_args::try_parse_legacy;
use opentelemetry::sdk::trace::TracerProvider as SdkTracerProvider;
use opentelemetry::trace::TracerProvider;
use log::{Level, Metadata, Record};
#[tokio::main]
async fn main() -> Result<(), std::convert::Infallible> {
let raw_args = std::env::args_os().collect::<Vec<_>>();
let parsed = try_parse_legacy(&raw_args).unwrap_or_else(|| args::Cli::parse_from(&raw_args));
let context = CommandContext {
http: reqwest::Client::new(),
paths: LauncherPaths::new(&parsed.global_options.cli_data_dir).unwrap(),
log: own_log::Logger::new(
SdkTracerProvider::builder().build().tracer("codecli"),
if parsed.global_options.verbose {
own_log::Level::Trace
} else {
parsed.global_options.log.unwrap_or(own_log::Level::Info)
},
),
args: parsed,
};
log::set_logger(Box::leak(Box::new(RustyLogger(context.log.clone()))))
.map(|()| log::set_max_level(log::LevelFilter::Debug))
.expect("expected to make logger");
let result = match context.args.subcommand.clone() {
None => {
let ca = context.args.get_base_code_args();
start_code(context, ca).await
}
Some(args::Commands::Extension(extension_args)) => {
let mut ca = context.args.get_base_code_args();
extension_args.add_code_args(&mut ca);
start_code(context, ca).await
}
Some(args::Commands::Status) => {
let mut ca = context.args.get_base_code_args();
ca.push("--status".to_string());
start_code(context, ca).await
}
Some(args::Commands::Version(version_args)) => match version_args.subcommand {
args::VersionSubcommand::Use(use_version_args) => {
version::switch_to(context, use_version_args).await
}
args::VersionSubcommand::Uninstall(uninstall_version_args) => {
version::uninstall(context, uninstall_version_args).await
}
args::VersionSubcommand::List(list_version_args) => {
version::list(context, list_version_args).await
}
},
Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand {
Some(args::TunnelSubcommand::Prune) => tunnels::prune(context).await,
Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context).await,
Some(args::TunnelSubcommand::Rename(rename_args)) => {
tunnels::rename(context, rename_args).await
}
Some(args::TunnelSubcommand::User(user_command)) => {
tunnels::user(context, user_command).await
}
Some(args::TunnelSubcommand::Service(service_args)) => {
tunnels::service(context, service_args).await
}
None => tunnels::serve(context, tunnel_args.serve_args).await,
},
};
match result {
Err(e) => print_and_exit(e),
Ok(code) => std::process::exit(code),
}
}
fn print_and_exit<E>(err: E) -> !
where
E: std::fmt::Display,
{
own_log::emit(own_log::Level::Error, "", &format!("{}", err));
std::process::exit(1);
}
async fn start_code(context: CommandContext, args: Vec<String>) -> Result<i32, AnyError> {
let platform = PreReqChecker::new().verify().await?;
let version_manager = desktop::CodeVersionManager::new(&context.paths, platform);
let update_service = UpdateService::new(context.log.clone(), context.http.clone());
let version = match &context.args.editor_options.code_options.use_version {
Some(v) => desktop::RequestedVersion::try_from(v.as_str())?,
None => version_manager.get_preferred_version(),
};
let binary = match version_manager.try_get_entrypoint(&version).await {
Some(ep) => ep,
None => {
desktop::prompt_to_install(&version)?;
version_manager.install(&update_service, &version).await?
}
};
let code = Command::new(binary)
.args(args)
.status()
.map(|s| s.code().unwrap_or(1))
.map_err(|e| wrap(e, "error running VS Code"))?;
Ok(code)
}
/// Logger that uses the common rust "log" crate and directs back to one of
/// our managed loggers.
struct RustyLogger(own_log::Logger);
impl log::Log for RustyLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
metadata.level() <= Level::Debug
}
fn log(&self, record: &Record) {
if !self.enabled(record.metadata()) {
return;
}
// exclude noisy log modules:
let src = match record.module_path() {
Some("russh::cipher") => return,
Some("russh::negotiation") => return,
Some(s) => s,
None => "<unknown>",
};
self.0.emit(
match record.level() {
log::Level::Debug => own_log::Level::Debug,
log::Level::Error => own_log::Level::Error,
log::Level::Info => own_log::Level::Info,
log::Level::Trace => own_log::Level::Trace,
log::Level::Warn => own_log::Level::Warn,
},
&format!("[{}] {}", src, record.args()),
);
}
fn flush(&self) {}
}

12
src/cli/src/commands.rs Normal file
View file

@ -0,0 +1,12 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
mod context;
mod output;
pub mod args;
pub mod tunnels;
pub mod version;
pub use context::CommandContext;

View file

@ -0,0 +1,590 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::fmt;
use crate::{constants, log, options, tunnels::code_server::CodeServerArgs};
use clap::{ArgEnum, Args, Parser, Subcommand};
const TEMPLATE: &str = "
Visual Studio Code CLI - {version}
Usage: code-insiders.exe [options][paths...]
To read output from another program, append '-' (e.g. 'echo Hello World | code-insiders.exe -')
{all-args}";
#[derive(Parser, Debug, Default)]
#[clap(
help_template = TEMPLATE,
long_about = None,
name = "Visual Studio Code CLI",
version = match constants::LAUNCHER_VERSION { Some(v) => v, None => "dev" },
)]
pub struct Cli {
/// One or more files, folders, or URIs to open.
#[clap(name = "paths")]
pub open_paths: Vec<String>,
#[clap(flatten, next_help_heading = Some("EDITOR OPTIONS"))]
pub editor_options: EditorOptions,
#[clap(flatten, next_help_heading = Some("EDITOR TROUBLESHOOTING"))]
pub troubleshooting: EditorTroubleshooting,
#[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))]
pub global_options: GlobalOptions,
#[clap(subcommand)]
pub subcommand: Option<Commands>,
}
impl Cli {
pub fn get_base_code_args(&self) -> Vec<String> {
let mut args = self.open_paths.clone();
self.editor_options.add_code_args(&mut args);
self.troubleshooting.add_code_args(&mut args);
self.global_options.add_code_args(&mut args);
args
}
}
impl<'a> From<&'a Cli> for CodeServerArgs {
fn from(cli: &'a Cli) -> Self {
let mut args = CodeServerArgs {
log: cli.global_options.log,
accept_server_license_terms: true,
..Default::default()
};
args.log = cli.global_options.log;
args.accept_server_license_terms = true;
if cli.global_options.verbose {
args.verbose = true;
}
if cli.global_options.disable_telemetry {
args.telemetry_level = Some(options::TelemetryLevel::Off);
} else if cli.global_options.telemetry_level.is_some() {
args.telemetry_level = cli.global_options.telemetry_level;
}
args
}
}
#[derive(Subcommand, Debug, Clone)]
pub enum Commands {
/// Create a tunnel that's accessible on vscode.dev from anywhere.
/// Run `code tunnel --help` for more usage info.
Tunnel(TunnelArgs),
/// Manage VS Code extensions.
#[clap(name = "ext")]
Extension(ExtensionArgs),
/// Print process usage and diagnostics information.
Status,
/// Changes the version of VS Code you're using.
Version(VersionArgs),
}
#[derive(Args, Debug, Clone)]
pub struct ExtensionArgs {
#[clap(subcommand)]
pub subcommand: ExtensionSubcommand,
#[clap(flatten)]
pub desktop_code_options: DesktopCodeOptions,
}
impl ExtensionArgs {
pub fn add_code_args(&self, target: &mut Vec<String>) {
if let Some(ed) = &self.desktop_code_options.extensions_dir {
target.push(ed.to_string());
}
self.subcommand.add_code_args(target);
}
}
#[derive(Subcommand, Debug, Clone)]
pub enum ExtensionSubcommand {
/// List installed extensions.
List(ListExtensionArgs),
/// Install an extension.
Install(InstallExtensionArgs),
/// Uninstall an extension.
Uninstall(UninstallExtensionArgs),
}
impl ExtensionSubcommand {
pub fn add_code_args(&self, target: &mut Vec<String>) {
match self {
ExtensionSubcommand::List(args) => {
target.push("--list-extensions".to_string());
if args.show_versions {
target.push("--show-versions".to_string());
}
if let Some(category) = &args.category {
target.push(format!("--category={}", category));
}
}
ExtensionSubcommand::Install(args) => {
for id in args.id_or_path.iter() {
target.push(format!("--install-extension={}", id));
}
if args.pre_release {
target.push("--pre-release".to_string());
}
if args.force {
target.push("--force".to_string());
}
}
ExtensionSubcommand::Uninstall(args) => {
for id in args.id.iter() {
target.push(format!("--uninstall-extension={}", id));
}
}
}
}
}
#[derive(Args, Debug, Clone)]
pub struct ListExtensionArgs {
/// Filters installed extensions by provided category, when using --list-extensions.
#[clap(long, value_name = "category")]
pub category: Option<String>,
/// Show versions of installed extensions, when using --list-extensions.
#[clap(long)]
pub show_versions: bool,
}
#[derive(Args, Debug, Clone)]
pub struct InstallExtensionArgs {
/// Either an extension id or a path to a VSIX. The identifier of an
/// extension is '${publisher}.${name}'. Use '--force' argument to update
/// to latest version. To install a specific version provide '@${version}'.
/// For example: 'vscode.csharp@1.2.3'.
#[clap(name = "ext-id | id")]
pub id_or_path: Vec<String>,
/// Installs the pre-release version of the extension
#[clap(long)]
pub pre_release: bool,
/// Update to the latest version of the extension if it's already installed.
#[clap(long)]
pub force: bool,
}
#[derive(Args, Debug, Clone)]
pub struct UninstallExtensionArgs {
/// One or more extension identifiers to uninstall. The identifier of an
/// extension is '${publisher}.${name}'. Use '--force' argument to update
/// to latest version.
#[clap(name = "ext-id")]
pub id: Vec<String>,
}
#[derive(Args, Debug, Clone)]
pub struct VersionArgs {
#[clap(subcommand)]
pub subcommand: VersionSubcommand,
}
#[derive(Subcommand, Debug, Clone)]
pub enum VersionSubcommand {
/// Switches the instance of VS Code in use.
Use(UseVersionArgs),
/// Uninstalls a instance of VS Code.
Uninstall(UninstallVersionArgs),
/// Lists installed VS Code instances.
List(OutputFormatOptions),
}
#[derive(Args, Debug, Clone)]
pub struct UseVersionArgs {
/// The version of VS Code you want to use. Can be "stable", "insiders",
/// a version number, or an absolute path to an existing install.
#[clap(value_name = "stable | insiders | x.y.z | path")]
pub name: String,
/// The directory the version should be installed into, if it's not already installed.
#[clap(long, value_name = "path")]
pub install_dir: Option<String>,
/// Reinstall the version even if it's already installed.
#[clap(long)]
pub reinstall: bool,
}
#[derive(Args, Debug, Clone)]
pub struct UninstallVersionArgs {
/// The version of VS Code to uninstall. Can be "stable", "insiders", or a
/// version number previous passed to `code version use <version>`.
#[clap(value_name = "stable | insiders | x.y.z")]
pub name: String,
}
#[derive(Args, Debug, Default)]
pub struct EditorOptions {
/// Compare two files with each other.
#[clap(short, long, value_names = &["file", "file"])]
pub diff: Vec<String>,
/// Add folder(s) to the last active window.
#[clap(short, long, value_name = "folder")]
pub add: Option<String>,
/// Open a file at the path on the specified line and character position.
#[clap(short, long, value_name = "file:line[:character]")]
pub goto: Option<String>,
/// Force to open a new window.
#[clap(short, long)]
pub new_window: bool,
/// Force to open a file or folder in an
#[clap(short, long)]
pub reuse_window: bool,
/// Wait for the files to be closed before returning.
#[clap(short, long)]
pub wait: bool,
/// The locale to use (e.g. en-US or zh-TW).
#[clap(long, value_name = "locale")]
pub locale: Option<String>,
/// Enables proposed API features for extensions. Can receive one or
/// more extension IDs to enable individually.
#[clap(long, value_name = "ext-id")]
pub enable_proposed_api: Vec<String>,
#[clap(flatten)]
pub code_options: DesktopCodeOptions,
}
impl EditorOptions {
pub fn add_code_args(&self, target: &mut Vec<String>) {
if !self.diff.is_empty() {
target.push("--diff".to_string());
for file in self.diff.iter() {
target.push(file.clone());
}
}
if let Some(add) = &self.add {
target.push("--add".to_string());
target.push(add.clone());
}
if let Some(goto) = &self.goto {
target.push("--goto".to_string());
target.push(goto.clone());
}
if self.new_window {
target.push("--new-window".to_string());
}
if self.reuse_window {
target.push("--reuse-window".to_string());
}
if self.wait {
target.push("--wait".to_string());
}
if let Some(locale) = &self.locale {
target.push(format!("--locale={}", locale));
}
if !self.enable_proposed_api.is_empty() {
for id in self.enable_proposed_api.iter() {
target.push(format!("--enable-proposed-api={}", id));
}
}
self.code_options.add_code_args(target);
}
}
/// Arguments applicable whenever VS Code desktop is launched
#[derive(Args, Debug, Default, Clone)]
pub struct DesktopCodeOptions {
/// Set the root path for extensions.
#[clap(long, value_name = "dir")]
pub extensions_dir: Option<String>,
/// Specifies the directory that user data is kept in. Can be used to
/// open multiple distinct instances of Code.
#[clap(long, value_name = "dir")]
pub user_data_dir: Option<String>,
/// Sets the VS Code version to use for this command. The preferred version
/// can be persisted with `code version use <version>`. Can be "stable",
/// "insiders", a version number, or an absolute path to an existing install.
#[clap(long, value_name = "stable | insiders | x.y.z | path")]
pub use_version: Option<String>,
}
/// Argument specifying the output format.
#[derive(Args, Debug, Clone)]
pub struct OutputFormatOptions {
/// Set the data output formats.
#[clap(arg_enum, long, value_name = "format", default_value_t = OutputFormat::Text)]
pub format: OutputFormat,
}
impl DesktopCodeOptions {
pub fn add_code_args(&self, target: &mut Vec<String>) {
if let Some(extensions_dir) = &self.extensions_dir {
target.push(format!("--extensions-dir={}", extensions_dir));
}
if let Some(user_data_dir) = &self.user_data_dir {
target.push(format!("--user-data-dir={}", user_data_dir));
}
}
}
#[derive(Args, Debug, Default)]
pub struct GlobalOptions {
/// Directory where CLI metadata, such as VS Code installations, should be stored.
#[clap(long, env = "VSCODE_CLI_DATA_DIR", global = true)]
pub cli_data_dir: Option<String>,
/// Print verbose output (implies --wait).
#[clap(long, global = true)]
pub verbose: bool,
/// Log level to use.
#[clap(long, arg_enum, value_name = "level", global = true)]
pub log: Option<log::Level>,
/// Disable telemetry for the current command, even if it was previously
/// accepted as part of the license prompt or specified in '--telemetry-level'
#[clap(long, global = true, hide = true)]
pub disable_telemetry: bool,
/// Sets the initial telemetry level
#[clap(arg_enum, long, global = true, hide = true)]
pub telemetry_level: Option<options::TelemetryLevel>,
}
impl GlobalOptions {
pub fn add_code_args(&self, target: &mut Vec<String>) {
if self.verbose {
target.push("--verbose".to_string());
}
if let Some(log) = self.log {
target.push(format!("--log={}", log));
}
if self.disable_telemetry {
target.push("--disable-telemetry".to_string());
}
if let Some(telemetry_level) = &self.telemetry_level {
target.push(format!("--telemetry-level={}", telemetry_level));
}
}
}
#[derive(Args, Debug, Default)]
pub struct EditorTroubleshooting {
/// Run CPU profiler during startup.
#[clap(long)]
pub prof_startup: bool,
/// Disable all installed extensions.
#[clap(long)]
pub disable_extensions: bool,
/// Disable an extension.
#[clap(long, value_name = "ext-id")]
pub disable_extension: Vec<String>,
/// Turn sync on or off.
#[clap(arg_enum, long, value_name = "on | off")]
pub sync: Option<SyncState>,
/// Allow debugging and profiling of extensions. Check the developer tools for the connection URI.
#[clap(long, value_name = "port")]
pub inspect_extensions: Option<u16>,
/// Allow debugging and profiling of extensions with the extension host
/// being paused after start. Check the developer tools for the connection URI.
#[clap(long, value_name = "port")]
pub inspect_brk_extensions: Option<u16>,
/// Disable GPU hardware acceleration.
#[clap(long)]
pub disable_gpu: bool,
/// Max memory size for a window (in Mbytes).
#[clap(long, value_name = "memory")]
pub max_memory: Option<usize>,
/// Shows all telemetry events which VS code collects.
#[clap(long)]
pub telemetry: bool,
}
impl EditorTroubleshooting {
pub fn add_code_args(&self, target: &mut Vec<String>) {
if self.prof_startup {
target.push("--prof-startup".to_string());
}
if self.disable_extensions {
target.push("--disable-extensions".to_string());
}
for id in self.disable_extension.iter() {
target.push(format!("--disable-extension={}", id));
}
if let Some(sync) = &self.sync {
target.push(format!("--sync={}", sync));
}
if let Some(port) = &self.inspect_extensions {
target.push(format!("--inspect-extensions={}", port));
}
if let Some(port) = &self.inspect_brk_extensions {
target.push(format!("--inspect-brk-extensions={}", port));
}
if self.disable_gpu {
target.push("--disable-gpu".to_string());
}
if let Some(memory) = &self.max_memory {
target.push(format!("--max-memory={}", memory));
}
if self.telemetry {
target.push("--telemetry".to_string());
}
}
}
#[derive(ArgEnum, Clone, Copy, Debug)]
pub enum SyncState {
On,
Off,
}
impl fmt::Display for SyncState {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SyncState::Off => write!(f, "off"),
SyncState::On => write!(f, "on"),
}
}
}
#[derive(ArgEnum, Clone, Copy, Debug)]
pub enum OutputFormat {
Json,
Text,
}
#[derive(Args, Clone, Debug, Default)]
pub struct ExistingTunnelArgs {
/// Name you'd like to assign preexisting tunnel to use to connect the tunnel
#[clap(long, hide = true)]
pub tunnel_name: Option<String>,
/// Token to authenticate and use preexisting tunnel
#[clap(long, hide = true)]
pub host_token: Option<String>,
/// ID of preexisting tunnel to use to connect the tunnel
#[clap(long, hide = true)]
pub tunnel_id: Option<String>,
/// Cluster of preexisting tunnel to use to connect the tunnel
#[clap(long, hide = true)]
pub cluster: Option<String>,
}
#[derive(Args, Debug, Clone, Default)]
pub struct TunnelServeArgs {
/// Optional details to connect to an existing tunnel
#[clap(flatten, next_help_heading = Some("ADVANCED OPTIONS"))]
pub tunnel: ExistingTunnelArgs,
/// Randomly name machine for port forwarding service
#[clap(long)]
pub random_name: bool,
}
#[derive(Args, Debug, Clone)]
pub struct TunnelArgs {
#[clap(subcommand)]
pub subcommand: Option<TunnelSubcommand>,
#[clap(flatten)]
pub serve_args: TunnelServeArgs,
}
#[derive(Subcommand, Debug, Clone)]
pub enum TunnelSubcommand {
/// Delete all servers which are currently not running.
Prune,
/// Rename the name of this machine associated with port forwarding service.
Rename(TunnelRenameArgs),
/// Remove this machine's association with the port forwarding service.
Unregister,
#[clap(subcommand)]
User(TunnelUserSubCommands),
/// Manages the tunnel when installed as a system service,
#[clap(subcommand)]
Service(TunnelServiceSubCommands),
}
#[derive(Subcommand, Debug, Clone)]
pub enum TunnelServiceSubCommands {
/// Installs or re-installs the tunnel service on the machine.
Install,
/// Uninstalls and stops the tunnel service.
Uninstall,
/// Internal command for running the service
#[clap(hide = true)]
InternalRun,
}
#[derive(Args, Debug, Clone)]
pub struct TunnelRenameArgs {
/// The name you'd like to rename your machine to.
pub name: String,
}
#[derive(Subcommand, Debug, Clone)]
pub enum TunnelUserSubCommands {
/// Log in to port forwarding service
Login(LoginArgs),
/// Log out of port forwarding service
Logout,
/// Show the account that's logged into port forwarding service
Show,
}
#[derive(Args, Debug, Clone)]
pub struct LoginArgs {
/// An access token to store for authentication. Note: this will not be
/// refreshed if it expires!
#[clap(long, requires = "provider")]
pub access_token: Option<String>,
/// The auth provider to use. If not provided, a prompt will be shown.
#[clap(arg_enum, long)]
pub provider: Option<AuthProvider>,
}
#[derive(clap::ArgEnum, Debug, Clone, Copy)]
pub enum AuthProvider {
Microsoft,
Github,
}

View file

@ -0,0 +1,15 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::{log, state::LauncherPaths};
use super::args::Cli;
pub struct CommandContext {
pub log: log::Logger,
pub paths: LauncherPaths,
pub args: Cli,
pub http: reqwest::Client,
}

View file

@ -0,0 +1,135 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::fmt::Display;
use std::io::{BufWriter, Write};
use super::args::OutputFormat;
pub struct Column {
max_width: usize,
heading: &'static str,
data: Vec<String>,
}
impl Column {
pub fn new(heading: &'static str) -> Self {
Column {
max_width: heading.len(),
heading,
data: vec![],
}
}
pub fn add_row(&mut self, row: String) {
self.max_width = std::cmp::max(self.max_width, row.len());
self.data.push(row);
}
}
impl OutputFormat {
pub fn print_table(&self, table: OutputTable) -> Result<(), std::io::Error> {
match *self {
OutputFormat::Json => JsonTablePrinter().print(table, &mut std::io::stdout()),
OutputFormat::Text => TextTablePrinter().print(table, &mut std::io::stdout()),
}
}
}
pub struct OutputTable {
cols: Vec<Column>,
}
impl OutputTable {
pub fn new(cols: Vec<Column>) -> Self {
OutputTable { cols }
}
}
trait TablePrinter {
fn print(&self, table: OutputTable, out: &mut dyn std::io::Write)
-> Result<(), std::io::Error>;
}
pub struct JsonTablePrinter();
impl TablePrinter for JsonTablePrinter {
fn print(
&self,
table: OutputTable,
out: &mut dyn std::io::Write,
) -> Result<(), std::io::Error> {
let mut bw = BufWriter::new(out);
bw.write_all(b"[")?;
if !table.cols.is_empty() {
let data_len = table.cols[0].data.len();
for i in 0..data_len {
if i > 0 {
bw.write_all(b",{")?;
} else {
bw.write_all(b"{")?;
}
for col in &table.cols {
serde_json::to_writer(&mut bw, col.heading)?;
bw.write_all(b":")?;
serde_json::to_writer(&mut bw, &col.data[i])?;
}
}
}
bw.write_all(b"]")?;
bw.flush()
}
}
/// Type that prints the output as an ASCII, markdown-style table.
pub struct TextTablePrinter();
impl TablePrinter for TextTablePrinter {
fn print(
&self,
table: OutputTable,
out: &mut dyn std::io::Write,
) -> Result<(), std::io::Error> {
let mut bw = BufWriter::new(out);
let sizes = table.cols.iter().map(|c| c.max_width).collect::<Vec<_>>();
// print headers
write_columns(&mut bw, table.cols.iter().map(|c| c.heading), &sizes)?;
// print --- separators
write_columns(
&mut bw,
table.cols.iter().map(|c| "-".repeat(c.max_width)),
&sizes,
)?;
// print each column
if !table.cols.is_empty() {
let data_len = table.cols[0].data.len();
for i in 0..data_len {
write_columns(&mut bw, table.cols.iter().map(|c| &c.data[i]), &sizes)?;
}
}
bw.flush()
}
}
fn write_columns<T>(
mut w: impl Write,
cols: impl Iterator<Item = T>,
sizes: &[usize],
) -> Result<(), std::io::Error>
where
T: Display,
{
w.write_all(b"|")?;
for (i, col) in cols.enumerate() {
write!(w, " {:width$} |", col, width = sizes[i])?;
}
w.write_all(b"\r\n")
}

View file

@ -0,0 +1,261 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::process::Stdio;
use async_trait::async_trait;
use tokio::sync::oneshot;
use super::{
args::{
AuthProvider, Cli, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs,
TunnelServiceSubCommands, TunnelUserSubCommands,
},
CommandContext,
};
use crate::{
auth::Auth,
log::{self, Logger},
state::LauncherPaths,
tunnels::{
code_server::CodeServerArgs, create_service_manager, dev_tunnels, legal,
paths::get_all_servers, ServiceContainer, ServiceManager,
},
util::{
errors::{wrap, AnyError},
prereqs::PreReqChecker,
},
};
impl From<AuthProvider> for crate::auth::AuthProvider {
fn from(auth_provider: AuthProvider) -> Self {
match auth_provider {
AuthProvider::Github => crate::auth::AuthProvider::Github,
AuthProvider::Microsoft => crate::auth::AuthProvider::Microsoft,
}
}
}
impl From<ExistingTunnelArgs> for Option<dev_tunnels::ExistingTunnel> {
fn from(d: ExistingTunnelArgs) -> Option<dev_tunnels::ExistingTunnel> {
if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) =
(d.tunnel_id, d.tunnel_name, d.cluster, d.host_token)
{
Some(dev_tunnels::ExistingTunnel {
tunnel_id,
tunnel_name,
host_token,
cluster,
})
} else {
None
}
}
}
struct TunnelServiceContainer {
args: Cli,
}
impl TunnelServiceContainer {
fn new(args: Cli) -> Self {
Self { args }
}
}
#[async_trait]
impl ServiceContainer for TunnelServiceContainer {
async fn run_service(
&mut self,
log: log::Logger,
launcher_paths: LauncherPaths,
shutdown_rx: oneshot::Receiver<()>,
) -> Result<(), AnyError> {
let csa = (&self.args).into();
serve_with_csa(
launcher_paths,
log,
TunnelServeArgs {
random_name: true, // avoid prompting
..Default::default()
},
csa,
Some(shutdown_rx),
)
.await?;
Ok(())
}
}
pub async fn service(
ctx: CommandContext,
service_args: TunnelServiceSubCommands,
) -> Result<i32, AnyError> {
let manager = create_service_manager(ctx.log.clone());
match service_args {
TunnelServiceSubCommands::Install => {
// ensure logged in, otherwise subsequent serving will fail
Auth::new(&ctx.paths, ctx.log.clone())
.get_credential()
.await?;
// likewise for license consent
legal::require_consent(&ctx.paths)?;
let current_exe =
std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?;
manager.register(
current_exe,
&[
"--cli-data-dir",
ctx.paths.root().as_os_str().to_string_lossy().as_ref(),
"tunnel",
"service",
"internal-run",
],
)?;
ctx.log.result("Service successfully installed! You can use `code tunnel service log` to monitor it, and `code tunnel service uninstall` to remove it.");
}
TunnelServiceSubCommands::Uninstall => {
manager.unregister()?;
}
TunnelServiceSubCommands::InternalRun => {
manager.run(ctx.paths.clone(), TunnelServiceContainer::new(ctx.args))?;
}
}
Ok(0)
}
pub async fn user(ctx: CommandContext, user_args: TunnelUserSubCommands) -> Result<i32, AnyError> {
let auth = Auth::new(&ctx.paths, ctx.log.clone());
match user_args {
TunnelUserSubCommands::Login(login_args) => {
auth.login(
login_args.provider.map(|p| p.into()),
login_args.access_token.to_owned(),
)
.await?;
}
TunnelUserSubCommands::Logout => {
auth.clear_credentials()?;
}
TunnelUserSubCommands::Show => {
if let Ok(Some(_)) = auth.get_current_credential() {
ctx.log.result("logged in");
} else {
ctx.log.result("not logged in");
return Ok(1);
}
}
}
Ok(0)
}
/// Remove the tunnel used by this gateway, if any.
pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Result<i32, AnyError> {
let auth = Auth::new(&ctx.paths, ctx.log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
dt.rename_tunnel(&rename_args.name).await?;
ctx.log.result(&format!(
"Successfully renamed this gateway to {}",
&rename_args.name
));
Ok(0)
}
/// Remove the tunnel used by this gateway, if any.
pub async fn unregister(ctx: CommandContext) -> Result<i32, AnyError> {
let auth = Auth::new(&ctx.paths, ctx.log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
dt.remove_tunnel().await?;
Ok(0)
}
/// Removes unused servers.
pub async fn prune(ctx: CommandContext) -> Result<i32, AnyError> {
get_all_servers(&ctx.paths)
.into_iter()
.map(|s| s.server_paths(&ctx.paths))
.filter(|s| s.get_running_pid().is_none())
.try_for_each(|s| {
ctx.log
.result(&format!("Deleted {}", s.server_dir.display()));
s.delete()
})
.map_err(AnyError::from)?;
ctx.log.result("Successfully removed all unused servers");
Ok(0)
}
/// Starts the gateway server.
pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result<i32, AnyError> {
let CommandContext {
log, paths, args, ..
} = ctx;
legal::require_consent(&paths)?;
let csa = (&args).into();
serve_with_csa(paths, log, gateway_args, csa, None).await
}
async fn serve_with_csa(
paths: LauncherPaths,
log: Logger,
gateway_args: TunnelServeArgs,
csa: CodeServerArgs,
shutdown_rx: Option<oneshot::Receiver<()>>,
) -> Result<i32, AnyError> {
let platform = spanf!(log, log.span("prereq"), PreReqChecker::new().verify())?;
let auth = Auth::new(&paths, log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths);
let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() {
dt.start_existing_tunnel(d).await
} else {
dt.start_new_launcher_tunnel(gateway_args.random_name).await
}?;
let shutdown_tx = if let Some(tx) = shutdown_rx {
tx
} else {
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tx.send(()).ok();
});
rx
};
let mut r = crate::tunnels::serve(&log, tunnel, &paths, &csa, platform, shutdown_tx).await?;
r.tunnel.close().await.ok();
if r.respawn {
warning!(log, "respawn requested, starting new server");
// reuse current args, but specify no-forward since tunnels will
// already be running in this process, and we cannot do a login
let args = std::env::args().skip(1).collect::<Vec<String>>();
let exit = std::process::Command::new(std::env::current_exe().unwrap())
.args(args)
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.stdin(Stdio::inherit())
.spawn()
.map_err(|e| wrap(e, "error respawning after update"))?
.wait()
.map_err(|e| wrap(e, "error waiting for child"))?;
return Ok(exit.code().unwrap_or(1));
}
Ok(0)
}

View file

@ -0,0 +1,66 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::{
desktop::{CodeVersionManager, RequestedVersion},
log,
update_service::UpdateService,
util::{errors::AnyError, prereqs::PreReqChecker},
};
use super::{
args::{OutputFormatOptions, UninstallVersionArgs, UseVersionArgs},
output::{Column, OutputTable},
CommandContext,
};
pub async fn switch_to(ctx: CommandContext, args: UseVersionArgs) -> Result<i32, AnyError> {
let platform = PreReqChecker::new().verify().await?;
let vm = CodeVersionManager::new(&ctx.paths, platform);
let version = RequestedVersion::try_from(args.name.as_str())?;
if !args.reinstall && vm.try_get_entrypoint(&version).await.is_some() {
vm.set_preferred_version(&version)?;
print_now_using(&ctx.log, &version);
return Ok(0);
}
let update_service = UpdateService::new(ctx.log.clone(), ctx.http.clone());
vm.install(&update_service, &version).await?;
vm.set_preferred_version(&version)?;
print_now_using(&ctx.log, &version);
Ok(0)
}
pub async fn list(ctx: CommandContext, args: OutputFormatOptions) -> Result<i32, AnyError> {
let platform = PreReqChecker::new().verify().await?;
let vm = CodeVersionManager::new(&ctx.paths, platform);
let mut name = Column::new("Installation");
let mut command = Column::new("Command");
for version in vm.list() {
name.add_row(version.to_string());
command.add_row(version.get_command());
}
args.format
.print_table(OutputTable::new(vec![name, command]))
.ok();
Ok(0)
}
pub async fn uninstall(ctx: CommandContext, args: UninstallVersionArgs) -> Result<i32, AnyError> {
let platform = PreReqChecker::new().verify().await?;
let vm = CodeVersionManager::new(&ctx.paths, platform);
let version = RequestedVersion::try_from(args.name.as_str())?;
vm.uninstall(&version).await?;
ctx.log
.result(&format!("VS Code {} uninstalled successfully", version));
Ok(0)
}
fn print_now_using(log: &log::Logger, version: &RequestedVersion) {
log.result(&format!("Now using VS Code {}", version));
}

43
src/cli/src/constants.rs Normal file
View file

@ -0,0 +1,43 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use lazy_static::lazy_static;
pub const CONTROL_PORT: u16 = 31545;
pub const PROTOCOL_VERSION: u32 = 1;
pub const LAUNCHER_VERSION: Option<&'static str> = option_env!("LAUNCHER_VERSION");
pub const LAUNCHER_ASSET_NAME: Option<&'static str> =
if cfg!(all(target_os = "macos", target_arch = "x86_64")) {
Some("x86_64-apple-darwin-signed")
} else if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
Some("aarch64-apple-darwin-signed")
} else if cfg!(all(target_os = "windows", target_arch = "x86_64")) {
Some("x86_64-pc-windows-msvc-signed")
} else if cfg!(all(target_os = "windows", target_arch = "aarch64")) {
Some("aarch64-pc-windows-msvc-signed")
} else {
option_env!("LAUNCHER_ASSET_NAME")
};
pub const LAUNCHER_AI_KEY: Option<&'static str> = option_env!("LAUNCHER_AI_KEY");
pub const LAUNCHER_AI_ENDPOINT: Option<&'static str> = option_env!("LAUNCHER_AI_ENDPOINT");
pub const TUNNEL_SERVICE_USER_AGENT_ENV_VAR: &str = "TUNNEL_SERVICE_USER_AGENT";
pub fn get_default_user_agent() -> String {
format!(
"vscode-server-launcher/{}",
LAUNCHER_VERSION.unwrap_or("dev")
)
}
lazy_static! {
pub static ref TUNNEL_SERVICE_USER_AGENT: String =
match std::env::var(TUNNEL_SERVICE_USER_AGENT_ENV_VAR) {
Ok(ua) if !ua.is_empty() => format!("{} {}", ua, get_default_user_agent()),
_ => get_default_user_agent(),
};
}

8
src/cli/src/desktop.rs Normal file
View file

@ -0,0 +1,8 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
mod version_manager;
pub use version_manager::{prompt_to_install, CodeVersionManager, RequestedVersion};

View file

@ -0,0 +1,492 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{
fmt,
path::{Path, PathBuf},
};
use indicatif::ProgressBar;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize};
use tokio::fs::remove_dir_all;
use crate::{
options,
state::{LauncherPaths, PersistedState},
update_service::{unzip_downloaded_release, Platform, Release, TargetKind, UpdateService},
util::{
errors::{
wrap, AnyError, InvalidRequestedVersion, MissingEntrypointError,
NoInstallInUserProvidedPath, UserCancelledInstallation, WrappedError,
},
http,
input::{prompt_yn, ProgressBarReporter},
},
};
/// Parsed instance that a user can request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "t", content = "c")]
pub enum RequestedVersion {
Quality(options::Quality),
Version {
version: String,
quality: options::Quality,
},
Commit {
commit: String,
quality: options::Quality,
},
Path(String),
}
lazy_static! {
static ref SEMVER_RE: Regex = Regex::new(r"^\d+\.\d+\.\d+(-insider)?$").unwrap();
static ref COMMIT_RE: Regex = Regex::new(r"^[a-z]+/[a-e0-f]{40}$").unwrap();
}
impl RequestedVersion {
pub fn get_command(&self) -> String {
match self {
RequestedVersion::Quality(quality) => {
format!("code version use {}", quality.get_machine_name())
}
RequestedVersion::Version { version, .. } => {
format!("code version use {}", version)
}
RequestedVersion::Commit { commit, quality } => {
format!("code version use {}/{}", quality.get_machine_name(), commit)
}
RequestedVersion::Path(path) => {
format!("code version use {}", path)
}
}
}
}
impl std::fmt::Display for RequestedVersion {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
RequestedVersion::Quality(quality) => write!(f, "{}", quality.get_capitalized_name()),
RequestedVersion::Version { version, .. } => {
write!(f, "{}", version)
}
RequestedVersion::Commit { commit, quality } => {
write!(f, "{}/{}", quality, commit)
}
RequestedVersion::Path(path) => write!(f, "{}", path),
}
}
}
impl TryFrom<&str> for RequestedVersion {
type Error = InvalidRequestedVersion;
fn try_from(s: &str) -> Result<Self, Self::Error> {
if let Ok(quality) = options::Quality::try_from(s) {
return Ok(RequestedVersion::Quality(quality));
}
if SEMVER_RE.is_match(s) {
return Ok(RequestedVersion::Version {
quality: if s.ends_with("-insider") {
options::Quality::Insiders
} else {
options::Quality::Stable
},
version: s.to_string(),
});
}
if Path::is_absolute(&PathBuf::from(s)) {
return Ok(RequestedVersion::Path(s.to_string()));
}
if COMMIT_RE.is_match(s) {
let idx = s.find('/').expect("expected a /");
if let Ok(quality) = options::Quality::try_from(&s[0..idx]) {
return Ok(RequestedVersion::Commit {
commit: s[idx + 1..].to_string(),
quality,
});
}
}
Err(InvalidRequestedVersion())
}
}
#[derive(Serialize, Deserialize, Clone, Default)]
struct Stored {
versions: Vec<RequestedVersion>,
current: usize,
}
pub struct CodeVersionManager {
state: PersistedState<Stored>,
platform: Platform,
storage_dir: PathBuf,
}
impl CodeVersionManager {
pub fn new(lp: &LauncherPaths, platform: Platform) -> Self {
CodeVersionManager {
state: PersistedState::new(lp.root().join("versions.json")),
storage_dir: lp.root().join("desktop"),
platform,
}
}
/// Sets the "version" as the persisted one for the user.
pub fn set_preferred_version(&self, version: &RequestedVersion) -> Result<(), AnyError> {
let mut stored = self.state.load();
if let Some(i) = stored.versions.iter().position(|v| v == version) {
stored.current = i;
} else {
stored.current = stored.versions.len();
stored.versions.push(version.clone());
}
self.state.save(stored)?;
Ok(())
}
/// Lists installed versions.
pub fn list(&self) -> Vec<RequestedVersion> {
self.state.load().versions
}
/// Uninstalls a previously installed version.
pub async fn uninstall(&self, version: &RequestedVersion) -> Result<(), AnyError> {
let mut stored = self.state.load();
if let Some(i) = stored.versions.iter().position(|v| v == version) {
if i > stored.current && i > 0 {
stored.current -= 1;
}
stored.versions.remove(i);
self.state.save(stored)?;
}
remove_dir_all(self.get_install_dir(version))
.await
.map_err(|e| wrap(e, "error deleting vscode directory"))?;
Ok(())
}
pub fn get_preferred_version(&self) -> RequestedVersion {
let stored = self.state.load();
stored
.versions
.get(stored.current)
.unwrap_or(&RequestedVersion::Quality(options::Quality::Stable))
.clone()
}
/// Installs the release for the given request. This always runs and does not
/// prompt, so you may want to use `try_get_entrypoint` first.
pub async fn install(
&self,
update_service: &UpdateService,
version: &RequestedVersion,
) -> Result<PathBuf, AnyError> {
let target_dir = self.get_install_dir(version);
let release = get_release_for_request(update_service, version, self.platform).await?;
install_release_into(update_service, &target_dir, &release).await?;
if let Some(p) = try_get_entrypoint(&target_dir).await {
return Ok(p);
}
Err(MissingEntrypointError().into())
}
/// Tries to get the entrypoint in the installed version, if one exists.
pub async fn try_get_entrypoint(&self, version: &RequestedVersion) -> Option<PathBuf> {
try_get_entrypoint(&self.get_install_dir(version)).await
}
fn get_install_dir(&self, version: &RequestedVersion) -> PathBuf {
let (name, quality) = match version {
RequestedVersion::Path(path) => return PathBuf::from(path),
RequestedVersion::Quality(quality) => (quality.get_machine_name(), quality),
RequestedVersion::Version {
quality,
version: number,
} => (number.as_str(), quality),
RequestedVersion::Commit { commit, quality } => (commit.as_str(), quality),
};
let mut dir = self.storage_dir.join(name);
if cfg!(target_os = "macos") {
dir.push(format!("{}.app", quality.get_app_name()))
}
dir
}
}
/// Shows a nice UI prompt to users asking them if they want to install the
/// requested version.
pub fn prompt_to_install(version: &RequestedVersion) -> Result<(), AnyError> {
if let RequestedVersion::Path(path) = version {
return Err(NoInstallInUserProvidedPath(path.clone()).into());
}
if !prompt_yn(&format!(
"VS Code {} is not installed yet, install it now?",
version
))? {
return Err(UserCancelledInstallation().into());
}
Ok(())
}
async fn get_release_for_request(
update_service: &UpdateService,
request: &RequestedVersion,
platform: Platform,
) -> Result<Release, WrappedError> {
match request {
RequestedVersion::Version {
quality,
version: number,
} => update_service
.get_release_by_semver_version(platform, TargetKind::Archive, *quality, number)
.await
.map_err(|e| wrap(e, "Could not get release")),
RequestedVersion::Commit { commit, quality } => Ok(Release {
platform,
commit: commit.clone(),
quality: *quality,
target: TargetKind::Archive,
}),
RequestedVersion::Quality(quality) => update_service
.get_latest_commit(platform, TargetKind::Archive, *quality)
.await
.map_err(|e| wrap(e, "Could not get release")),
_ => panic!("cannot get release info for a path"),
}
}
async fn install_release_into(
update_service: &UpdateService,
path: &Path,
release: &Release,
) -> Result<(), AnyError> {
let tempdir =
tempfile::tempdir().map_err(|e| wrap(e, "error creating temporary download dir"))?;
let save_path = tempdir.path().join("vscode");
let stream = update_service.get_download_stream(release).await?;
let pb = ProgressBar::new(1);
pb.set_message("Downloading...");
let progress = ProgressBarReporter::from(pb);
http::download_into_file(&save_path, progress, stream).await?;
let pb = ProgressBar::new(1);
pb.set_message("Unzipping...");
let progress = ProgressBarReporter::from(pb);
unzip_downloaded_release(&save_path, path, progress)?;
drop(tempdir);
Ok(())
}
/// Tries to find the binary entrypoint for VS Code installed in the path.
async fn try_get_entrypoint(path: &Path) -> Option<PathBuf> {
use tokio::sync::mpsc;
let (tx, mut rx) = mpsc::channel(1);
// Look for all the possible paths in parallel
for entry in DESKTOP_CLI_RELATIVE_PATH.split(',') {
let my_path = path.join(entry);
let my_tx = tx.clone();
tokio::spawn(async move {
if tokio::fs::metadata(&my_path).await.is_ok() {
my_tx.send(my_path).await.ok();
}
});
}
drop(tx); // drop so rx gets None if no sender emits
rx.recv().await
}
const DESKTOP_CLI_RELATIVE_PATH: &str = if cfg!(target_os = "macos") {
"Contents/Resources/app/bin/code"
} else if cfg!(target_os = "windows") {
"bin/code.cmd,bin/code-insiders.cmd,bin/code-exploration.cmd"
} else {
"bin/code,bin/code-insiders,bin/code-exploration"
};
#[cfg(test)]
mod tests {
use std::{
fs::{create_dir_all, File},
io::Write,
};
use super::*;
fn make_fake_vscode_install(path: &Path, quality: options::Quality) {
let bin = DESKTOP_CLI_RELATIVE_PATH
.split(',')
.next()
.expect("expected exe path");
let binary_file_path = if cfg!(target_os = "macos") {
path.join(format!("{}.app/{}", quality.get_app_name(), bin))
} else {
path.join(bin)
};
let parent_dir_path = binary_file_path.parent().expect("expected parent path");
create_dir_all(parent_dir_path).expect("expected to create parent dir");
let mut binary_file = File::create(binary_file_path).expect("expected to make file");
binary_file
.write_all(b"")
.expect("expected to write binary");
}
fn make_multiple_vscode_install() -> tempfile::TempDir {
let dir = tempfile::tempdir().expect("expected to make temp dir");
make_fake_vscode_install(&dir.path().join("desktop/stable"), options::Quality::Stable);
make_fake_vscode_install(&dir.path().join("desktop/1.68.2"), options::Quality::Stable);
dir
}
#[test]
fn test_requested_version_parses() {
assert_eq!(
RequestedVersion::try_from("1.2.3").unwrap(),
RequestedVersion::Version {
quality: options::Quality::Stable,
version: "1.2.3".to_string(),
}
);
assert_eq!(
RequestedVersion::try_from("1.2.3-insider").unwrap(),
RequestedVersion::Version {
quality: options::Quality::Insiders,
version: "1.2.3-insider".to_string(),
}
);
assert_eq!(
RequestedVersion::try_from("stable").unwrap(),
RequestedVersion::Quality(options::Quality::Stable)
);
assert_eq!(
RequestedVersion::try_from("insiders").unwrap(),
RequestedVersion::Quality(options::Quality::Insiders)
);
assert_eq!(
RequestedVersion::try_from("insiders/92fd228156aafeb326b23f6604028d342152313b")
.unwrap(),
RequestedVersion::Commit {
commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(),
quality: options::Quality::Insiders
}
);
assert_eq!(
RequestedVersion::try_from("stable/92fd228156aafeb326b23f6604028d342152313b").unwrap(),
RequestedVersion::Commit {
commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(),
quality: options::Quality::Stable
}
);
let exe = std::env::current_exe()
.expect("expected to get exe")
.to_string_lossy()
.to_string();
assert_eq!(
RequestedVersion::try_from((&exe).as_str()).unwrap(),
RequestedVersion::Path(exe),
);
}
#[test]
fn test_set_preferred_version() {
let dir = make_multiple_vscode_install();
let lp = LauncherPaths::new_without_replacements(dir.path().to_owned());
let vm1 = CodeVersionManager::new(&lp, Platform::LinuxARM64);
assert_eq!(
vm1.get_preferred_version(),
RequestedVersion::Quality(options::Quality::Stable)
);
vm1.set_preferred_version(&RequestedVersion::Quality(options::Quality::Exploration))
.expect("expected to store");
vm1.set_preferred_version(&RequestedVersion::Quality(options::Quality::Insiders))
.expect("expected to store");
assert_eq!(
vm1.get_preferred_version(),
RequestedVersion::Quality(options::Quality::Insiders)
);
let vm2 = CodeVersionManager::new(&lp, Platform::LinuxARM64);
assert_eq!(
vm2.get_preferred_version(),
RequestedVersion::Quality(options::Quality::Insiders)
);
assert_eq!(
vm2.list(),
vec![
RequestedVersion::Quality(options::Quality::Exploration),
RequestedVersion::Quality(options::Quality::Insiders)
]
);
}
#[tokio::test]
async fn test_gets_entrypoint() {
let dir = make_multiple_vscode_install();
let lp = LauncherPaths::new_without_replacements(dir.path().to_owned());
let vm = CodeVersionManager::new(&lp, Platform::LinuxARM64);
assert!(vm
.try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Stable))
.await
.is_some());
assert!(vm
.try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Exploration))
.await
.is_none());
}
#[tokio::test]
async fn test_uninstall() {
let dir = make_multiple_vscode_install();
let lp = LauncherPaths::new_without_replacements(dir.path().to_owned());
let vm = CodeVersionManager::new(&lp, Platform::LinuxARM64);
vm.uninstall(&RequestedVersion::Quality(options::Quality::Stable))
.await
.expect("expected to uninsetall");
assert!(vm
.try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Stable))
.await
.is_none());
}
}

19
src/cli/src/lib.rs Normal file
View file

@ -0,0 +1,19 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
// todo: we should reduce the exported surface area over time as things are
// moved into a common CLI
pub mod auth;
pub mod constants;
#[macro_use]
pub mod log;
pub mod commands;
pub mod desktop;
pub mod options;
pub mod tunnels;
pub mod state;
pub mod update;
pub mod update_service;
pub mod util;

389
src/cli/src/log.rs Normal file
View file

@ -0,0 +1,389 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use chrono::Local;
use opentelemetry::{
sdk::trace::Tracer,
trace::{SpanBuilder, Tracer as TraitTracer},
};
use std::fmt;
use std::{env, path::Path, sync::Arc};
use std::{
io::Write,
sync::atomic::{AtomicU32, Ordering},
};
const NO_COLOR_ENV: &str = "NO_COLOR";
static INSTANCE_COUNTER: AtomicU32 = AtomicU32::new(0);
// Gets a next incrementing number that can be used in logs
pub fn next_counter() -> u32 {
INSTANCE_COUNTER.fetch_add(1, Ordering::SeqCst)
}
// Log level
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug)]
pub enum Level {
Trace = 0,
Debug,
Info,
Warn,
Error,
Critical,
Off,
}
impl Default for Level {
fn default() -> Self {
Level::Info
}
}
impl fmt::Display for Level {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Level::Critical => write!(f, "critical"),
Level::Debug => write!(f, "debug"),
Level::Error => write!(f, "error"),
Level::Info => write!(f, "info"),
Level::Off => write!(f, "off"),
Level::Trace => write!(f, "trace"),
Level::Warn => write!(f, "warn"),
}
}
}
impl Level {
pub fn name(&self) -> Option<&str> {
match self {
Level::Trace => Some("trace"),
Level::Debug => Some("debug"),
Level::Info => Some("info"),
Level::Warn => Some("warn"),
Level::Error => Some("error"),
Level::Critical => Some("critical"),
Level::Off => None,
}
}
pub fn color_code(&self) -> Option<&str> {
if env::var(NO_COLOR_ENV).is_ok() || !atty::is(atty::Stream::Stdout) {
return None;
}
match self {
Level::Trace => None,
Level::Debug => Some("\x1b[36m"),
Level::Info => Some("\x1b[35m"),
Level::Warn => Some("\x1b[33m"),
Level::Error => Some("\x1b[31m"),
Level::Critical => Some("\x1b[31m"),
Level::Off => None,
}
}
pub fn to_u8(self) -> u8 {
self as u8
}
}
pub fn new_tunnel_prefix() -> String {
format!("[tunnel.{}]", next_counter())
}
pub fn new_code_server_prefix() -> String {
format!("[codeserver.{}]", next_counter())
}
pub fn new_rpc_prefix() -> String {
format!("[rpc.{}]", next_counter())
}
// Base logger implementation
#[derive(Clone)]
pub struct Logger {
tracer: Tracer,
sink: Vec<Box<dyn LogSink>>,
prefix: Option<String>,
}
// Copy trick from https://stackoverflow.com/a/30353928
pub trait LogSinkClone {
fn clone_box(&self) -> Box<dyn LogSink>;
}
impl<T> LogSinkClone for T
where
T: 'static + LogSink + Clone,
{
fn clone_box(&self) -> Box<dyn LogSink> {
Box::new(self.clone())
}
}
pub trait LogSink: LogSinkClone + Sync + Send {
fn write_log(&self, level: Level, prefix: &str, message: &str);
fn write_result(&self, message: &str);
}
impl Clone for Box<dyn LogSink> {
fn clone(&self) -> Box<dyn LogSink> {
self.clone_box()
}
}
#[derive(Clone)]
pub struct StdioLogSink {
level: Level,
}
impl LogSink for StdioLogSink {
fn write_log(&self, level: Level, prefix: &str, message: &str) {
if level < self.level {
return;
}
emit(level, prefix, message);
}
fn write_result(&self, message: &str) {
println!("{}", message);
}
}
#[derive(Clone)]
pub struct FileLogSink {
level: Level,
file: Arc<std::sync::Mutex<std::fs::File>>,
}
impl FileLogSink {
pub fn new(level: Level, path: &Path) -> std::io::Result<Self> {
let file = std::fs::File::create(path)?;
Ok(Self {
level,
file: Arc::new(std::sync::Mutex::new(file)),
})
}
}
impl LogSink for FileLogSink {
fn write_log(&self, level: Level, prefix: &str, message: &str) {
if level < self.level {
return;
}
let line = format(level, prefix, message);
// ignore any errors, not much we can do if logging fails...
self.file.lock().unwrap().write_all(line.as_bytes()).ok();
}
fn write_result(&self, _message: &str) {}
}
impl Logger {
pub fn new(tracer: Tracer, level: Level) -> Self {
Self {
tracer,
sink: vec![Box::new(StdioLogSink { level })],
prefix: None,
}
}
pub fn span(&self, name: &str) -> SpanBuilder {
self.tracer.span_builder(format!("serverlauncher/{}", name))
}
pub fn tracer(&self) -> &Tracer {
&self.tracer
}
pub fn emit(&self, level: Level, message: &str) {
let prefix = self.prefix.as_deref().unwrap_or("");
for sink in &self.sink {
sink.write_log(level, prefix, message);
}
}
pub fn result(&self, message: &str) {
for sink in &self.sink {
sink.write_result(message);
}
}
pub fn prefixed(&self, prefix: &str) -> Logger {
Logger {
prefix: Some(match &self.prefix {
Some(p) => format!("{}{} ", p, prefix),
None => format!("{} ", prefix),
}),
..self.clone()
}
}
/// Creates a new logger with the additional log sink added.
pub fn tee<T>(&self, sink: T) -> Logger
where
T: LogSink + 'static,
{
let mut new_sinks = self.sink.clone();
new_sinks.push(Box::new(sink));
Logger {
sink: new_sinks,
..self.clone()
}
}
pub fn get_download_logger<'a>(&'a self, prefix: &'static str) -> DownloadLogger<'a> {
DownloadLogger {
prefix,
logger: self,
}
}
}
pub struct DownloadLogger<'a> {
prefix: &'static str,
logger: &'a Logger,
}
impl<'a> crate::util::io::ReportCopyProgress for DownloadLogger<'a> {
fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) {
if total_bytes > 0 {
self.logger.emit(
Level::Trace,
&format!(
"{} {}/{} ({:.0}%)",
self.prefix,
bytes_so_far,
total_bytes,
(bytes_so_far as f64 / total_bytes as f64) * 100.0,
),
);
} else {
self.logger.emit(
Level::Trace,
&format!("{} {}/{}", self.prefix, bytes_so_far, total_bytes,),
);
}
}
}
pub fn format(level: Level, prefix: &str, message: &str) -> String {
let current = Local::now();
let timestamp = current.format("%Y-%m-%d %H:%M:%S").to_string();
let name = level.name().unwrap();
if let Some(c) = level.color_code() {
format!(
"\x1b[2m[{}]\x1b[0m {}{}\x1b[0m {}{}\n",
timestamp, c, name, prefix, message
)
} else {
format!("[{}] {} {}{}\n", timestamp, name, prefix, message)
}
}
pub fn emit(level: Level, prefix: &str, message: &str) {
let line = format(level, prefix, message);
if level == Level::Trace {
print!("\x1b[2m{}\x1b[0m", line);
} else {
print!("{}", line);
}
}
#[macro_export]
macro_rules! error {
($logger:expr, $str:expr) => {
$logger.emit(log::Level::Error, $str)
};
($logger:expr, $($fmt:expr),+) => {
$logger.emit(log::Level::Error, &format!($($fmt),+))
};
}
#[macro_export]
macro_rules! trace {
($logger:expr, $str:expr) => {
$logger.emit(log::Level::Trace, $str)
};
($logger:expr, $($fmt:expr),+) => {
$logger.emit(log::Level::Trace, &format!($($fmt),+))
};
}
#[macro_export]
macro_rules! debug {
($logger:expr, $str:expr) => {
$logger.emit(log::Level::Debug, $str)
};
($logger:expr, $($fmt:expr),+) => {
$logger.emit(log::Level::Debug, &format!($($fmt),+))
};
}
#[macro_export]
macro_rules! info {
($logger:expr, $str:expr) => {
$logger.emit(log::Level::Info, $str)
};
($logger:expr, $($fmt:expr),+) => {
$logger.emit(log::Level::Info, &format!($($fmt),+))
};
}
#[macro_export]
macro_rules! warning {
($logger:expr, $str:expr) => {
$logger.emit(log::Level::Warn, $str)
};
($logger:expr, $($fmt:expr),+) => {
$logger.emit(log::Level::Warn, &format!($($fmt),+))
};
}
#[macro_export]
macro_rules! span {
($logger:expr, $span:expr, $func:expr) => {{
use opentelemetry::trace::TraceContextExt;
let span = $span.start($logger.tracer());
let cx = opentelemetry::Context::current_with_span(span);
let guard = cx.clone().attach();
let t = $func;
if let Err(e) = &t {
cx.span().record_error(e);
}
std::mem::drop(guard);
t
}};
}
#[macro_export]
macro_rules! spanf {
($logger:expr, $span:expr, $func:expr) => {{
use opentelemetry::trace::{FutureExt, TraceContextExt};
let span = $span.start($logger.tracer());
let cx = opentelemetry::Context::current_with_span(span);
let t = $func.with_context(cx.clone()).await;
if let Err(e) = &t {
cx.span().record_error(e);
}
cx.span().end();
t
}};
}

104
src/cli/src/options.rs Normal file
View file

@ -0,0 +1,104 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::fmt;
use serde::{Deserialize, Serialize};
#[derive(clap::ArgEnum, Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Quality {
#[serde(rename = "stable")]
Stable,
#[serde(rename = "exploration")]
Exploration,
#[serde(other)]
Insiders,
}
impl Quality {
/// Lowercased name in paths and protocol
pub fn get_machine_name(&self) -> &'static str {
match self {
Quality::Insiders => "insiders",
Quality::Exploration => "exploration",
Quality::Stable => "stable",
}
}
/// Uppercased display name for humans
pub fn get_capitalized_name(&self) -> &'static str {
match self {
Quality::Insiders => "Insiders",
Quality::Exploration => "Exploration",
Quality::Stable => "Stable",
}
}
pub fn get_app_name(&self) -> &'static str {
match self {
Quality::Insiders => "Visual Studio Code Insiders",
Quality::Exploration => "Visual Studio Code Exploration",
Quality::Stable => "Visual Studio Code",
}
}
#[cfg(target_os = "windows")]
pub fn server_entrypoint(&self) -> &'static str {
match self {
Quality::Insiders => "code-server-insiders.cmd",
Quality::Exploration => "code-server-exploration.cmd",
Quality::Stable => "code-server.cmd",
}
}
#[cfg(not(target_os = "windows"))]
pub fn server_entrypoint(&self) -> &'static str {
match self {
Quality::Insiders => "code-server-insiders",
Quality::Exploration => "code-server-exploration",
Quality::Stable => "code-server",
}
}
}
impl fmt::Display for Quality {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.get_capitalized_name())
}
}
impl TryFrom<&str> for Quality {
type Error = String;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"stable" => Ok(Quality::Stable),
"insiders" => Ok(Quality::Insiders),
"exploration" => Ok(Quality::Exploration),
_ => Err(format!(
"Unknown quality: {}. Must be one of stable, insiders, or exploration.",
s
)),
}
}
}
#[derive(clap::ArgEnum, Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TelemetryLevel {
Off,
Crash,
Error,
All,
}
impl fmt::Display for TelemetryLevel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TelemetryLevel::Off => write!(f, "off"),
TelemetryLevel::Crash => write!(f, "crash"),
TelemetryLevel::Error => write!(f, "error"),
TelemetryLevel::All => write!(f, "all"),
}
}
}

152
src/cli/src/state.rs Normal file
View file

@ -0,0 +1,152 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
extern crate dirs;
use std::{
fs::{create_dir, read_to_string, remove_dir_all, write},
path::{Path, PathBuf},
sync::{Arc, Mutex},
};
use serde::{de::DeserializeOwned, Serialize};
use crate::util::errors::{wrap, AnyError, NoHomeForLauncherError, WrappedError};
const HOME_DIR_ALTS: [&str; 2] = ["$HOME", "~"];
#[derive(Clone)]
pub struct LauncherPaths {
root: PathBuf,
}
struct PersistedStateContainer<T>
where
T: Clone + Serialize + DeserializeOwned + Default,
{
path: PathBuf,
state: Option<T>,
}
impl<T> PersistedStateContainer<T>
where
T: Clone + Serialize + DeserializeOwned + Default,
{
fn load_or_get(&mut self) -> T {
if let Some(state) = &self.state {
return state.clone();
}
let state = if let Ok(s) = read_to_string(&self.path) {
serde_json::from_str::<T>(&s).unwrap_or_default()
} else {
T::default()
};
self.state = Some(state.clone());
state
}
fn save(&mut self, state: T) -> Result<(), WrappedError> {
let s = serde_json::to_string(&state).unwrap();
self.state = Some(state);
write(&self.path, s).map_err(|e| {
wrap(
e,
format!("error saving launcher state into {}", self.path.display()),
)
})
}
}
/// Container that holds some state value that is persisted to disk.
#[derive(Clone)]
pub struct PersistedState<T>
where
T: Clone + Serialize + DeserializeOwned + Default,
{
container: Arc<Mutex<PersistedStateContainer<T>>>,
}
impl<T> PersistedState<T>
where
T: Clone + Serialize + DeserializeOwned + Default,
{
/// Creates a new state container that persists to the given path.
pub fn new(path: PathBuf) -> PersistedState<T> {
PersistedState {
container: Arc::new(Mutex::new(PersistedStateContainer { path, state: None })),
}
}
/// Loads persisted state.
pub fn load(&self) -> T {
self.container.lock().unwrap().load_or_get()
}
/// Saves persisted state.
pub fn save(&self, state: T) -> Result<(), WrappedError> {
self.container.lock().unwrap().save(state)
}
/// Mutates persisted state.
pub fn update_with<V, R>(
&self,
v: V,
mutator: fn(v: V, state: &mut T) -> R,
) -> Result<R, WrappedError> {
let mut container = self.container.lock().unwrap();
let mut state = container.load_or_get();
let r = mutator(v, &mut state);
container.save(state).map(|_| r)
}
}
impl LauncherPaths {
pub fn new(root: &Option<String>) -> Result<LauncherPaths, AnyError> {
let root = root.as_deref().unwrap_or("~/.vscode-cli");
let mut replaced = root.to_owned();
for token in HOME_DIR_ALTS {
if root.contains(token) {
if let Some(home) = dirs::home_dir() {
replaced = root.replace(token, &home.to_string_lossy())
} else {
return Err(AnyError::from(NoHomeForLauncherError()));
}
}
}
if !Path::new(&replaced).exists() {
create_dir(&replaced)
.map_err(|e| wrap(e, format!("error creating directory {}", &replaced)))?;
}
Ok(LauncherPaths::new_without_replacements(PathBuf::from(
replaced,
)))
}
pub fn new_without_replacements(root: PathBuf) -> LauncherPaths {
LauncherPaths { root }
}
/// Root directory for the server launcher
pub fn root(&self) -> &Path {
&self.root
}
/// Removes the launcher data directory.
pub fn remove(&self) -> Result<(), WrappedError> {
remove_dir_all(&self.root).map_err(|e| {
wrap(
e,
format!(
"error removing launcher data directory {}",
self.root.display()
),
)
})
}
}

25
src/cli/src/tunnels.rs Normal file
View file

@ -0,0 +1,25 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
pub mod code_server;
pub mod dev_tunnels;
pub mod legal;
pub mod paths;
mod control_server;
mod name_generator;
mod port_forwarder;
mod protocol;
#[cfg_attr(unix, path = "tunnels/server_bridge_unix.rs")]
#[cfg_attr(windows, path = "tunnels/server_bridge_windows.rs")]
mod server_bridge;
mod service;
#[cfg(target_os = "windows")]
mod service_windows;
pub use control_server::serve;
pub use service::{
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,
};

View file

@ -0,0 +1,756 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use super::paths::{InstalledServer, LastUsedServers, ServerPaths};
use crate::options::{Quality, TelemetryLevel};
use crate::state::LauncherPaths;
use crate::update_service::{
unzip_downloaded_release, Platform, Release, TargetKind, UpdateService,
};
use crate::util::command::{capture_command, kill_tree};
use crate::util::errors::{
wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError,
};
use crate::util::http;
use crate::util::io::SilentCopyProgress;
use crate::util::machine::process_exists;
use crate::{debug, info, log, span, spanf, trace, warning};
use lazy_static::lazy_static;
use opentelemetry::KeyValue;
use regex::Regex;
use serde::Deserialize;
use std::fs;
use std::fs::File;
use std::io::{ErrorKind, Write};
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::fs::remove_file;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::oneshot::Receiver;
use tokio::time::{interval, timeout};
use uuid::Uuid;
lazy_static! {
static ref LISTENING_PORT_RE: Regex =
Regex::new(r"Extension host agent listening on (.+)").unwrap();
static ref WEB_UI_RE: Regex = Regex::new(r"Web UI available at (.+)").unwrap();
}
const MAX_RETAINED_SERVERS: usize = 5;
#[derive(Clone, Debug, Default)]
pub struct CodeServerArgs {
pub host: Option<String>,
pub port: Option<u16>,
pub socket_path: Option<String>,
// common argument
pub telemetry_level: Option<TelemetryLevel>,
pub log: Option<log::Level>,
pub accept_server_license_terms: bool,
pub verbose: bool,
// extension management
pub install_extensions: Vec<String>,
pub uninstall_extensions: Vec<String>,
pub list_extensions: bool,
pub show_versions: bool,
pub category: Option<String>,
pub pre_release: bool,
pub force: bool,
pub start_server: bool,
// connection tokens
pub connection_token: Option<String>,
pub connection_token_file: Option<String>,
pub without_connection_token: bool,
}
impl CodeServerArgs {
pub fn log_level(&self) -> log::Level {
if self.verbose {
log::Level::Trace
} else {
self.log.unwrap_or(log::Level::Info)
}
}
pub fn telemetry_disabled(&self) -> bool {
self.telemetry_level == Some(TelemetryLevel::Off)
}
pub fn command_arguments(&self) -> Vec<String> {
let mut args = Vec::new();
if let Some(i) = &self.socket_path {
args.push(format!("--socket-path={}", i));
} else {
if let Some(i) = &self.host {
args.push(format!("--host={}", i));
}
if let Some(i) = &self.port {
args.push(format!("--port={}", i));
}
}
if let Some(i) = &self.connection_token {
args.push(format!("--connection-token={}", i));
}
if let Some(i) = &self.connection_token_file {
args.push(format!("--connection-token-file={}", i));
}
if self.without_connection_token {
args.push(String::from("--without-connection-token"));
}
if self.accept_server_license_terms {
args.push(String::from("--accept-server-license-terms"));
}
if let Some(i) = self.telemetry_level {
args.push(format!("--telemetry-level={}", i));
}
if let Some(i) = self.log {
args.push(format!("--log={}", i));
}
for extension in &self.install_extensions {
args.push(format!("--install-extension={}", extension));
}
if !&self.install_extensions.is_empty() {
if self.pre_release {
args.push(String::from("--pre-release"));
}
if self.force {
args.push(String::from("--force"));
}
}
for extension in &self.uninstall_extensions {
args.push(format!("--uninstall-extension={}", extension));
}
if self.list_extensions {
args.push(String::from("--list-extensions"));
if self.show_versions {
args.push(String::from("--show-versions"));
}
if let Some(i) = &self.category {
args.push(format!("--category={}", i));
}
}
if self.start_server {
args.push(String::from("--start-server"));
}
args
}
}
/// Base server params that can be `resolve()`d to a `ResolvedServerParams`.
/// Doing so fetches additional information like a commit ID if previously
/// unspecified.
pub struct ServerParamsRaw {
pub commit_id: Option<String>,
pub quality: Quality,
pub code_server_args: CodeServerArgs,
pub headless: bool,
pub platform: Platform,
}
/// Server params that can be used to start a VS Code server.
pub struct ResolvedServerParams {
pub release: Release,
pub code_server_args: CodeServerArgs,
}
impl ResolvedServerParams {
fn as_installed_server(&self) -> InstalledServer {
InstalledServer {
commit: self.release.commit.clone(),
quality: self.release.quality,
headless: self.release.target == TargetKind::Server,
}
}
}
impl ServerParamsRaw {
pub async fn resolve(self, log: &log::Logger) -> Result<ResolvedServerParams, AnyError> {
Ok(ResolvedServerParams {
release: self.get_or_fetch_commit_id(log).await?,
code_server_args: self.code_server_args,
})
}
async fn get_or_fetch_commit_id(&self, log: &log::Logger) -> Result<Release, AnyError> {
let target = match self.headless {
true => TargetKind::Server,
false => TargetKind::Web,
};
if let Some(c) = &self.commit_id {
return Ok(Release {
commit: c.clone(),
quality: self.quality,
target,
platform: self.platform,
});
}
UpdateService::new(log.clone(), reqwest::Client::new())
.get_latest_commit(self.platform, target, self.quality)
.await
}
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
struct UpdateServerVersion {
pub name: String,
pub version: String,
pub product_version: String,
pub timestamp: i64,
}
/// Code server listening on a port address.
pub struct SocketCodeServer {
pub commit_id: String,
pub socket: PathBuf,
pub origin: CodeServerOrigin,
}
/// Code server listening on a socket address.
pub struct PortCodeServer {
pub commit_id: String,
pub port: u16,
pub origin: CodeServerOrigin,
}
/// A server listening on any address/location.
pub enum AnyCodeServer {
Socket(SocketCodeServer),
Port(PortCodeServer),
}
// impl AnyCodeServer {
// pub fn origin(&mut self) -> &mut CodeServerOrigin {
// match self {
// AnyCodeServer::Socket(p) => &mut p.origin,
// AnyCodeServer::Port(p) => &mut p.origin,
// }
// }
// }
pub enum CodeServerOrigin {
/// A new code server, that opens the barrier when it exits.
New(Child),
/// An existing code server with a PID.
Existing(u32),
}
impl CodeServerOrigin {
pub async fn wait_for_exit(&mut self) {
match self {
CodeServerOrigin::New(child) => {
child.wait().await.ok();
}
CodeServerOrigin::Existing(pid) => {
let mut interval = interval(Duration::from_secs(30));
while process_exists(*pid) {
interval.tick().await;
}
}
}
}
pub async fn kill(&mut self) {
match self {
CodeServerOrigin::New(child) => {
child.kill().await.ok();
}
CodeServerOrigin::Existing(pid) => {
kill_tree(*pid).await.ok();
}
}
}
}
async fn check_and_create_dir(path: &Path) -> Result<(), WrappedError> {
tokio::fs::create_dir_all(path)
.await
.map_err(|e| wrap(e, "error creating server directory"))?;
Ok(())
}
async fn install_server_if_needed(
log: &log::Logger,
paths: &ServerPaths,
release: &Release,
) -> Result<(), AnyError> {
if paths.executable.exists() {
info!(
log,
"Found existing installation at {}",
paths.server_dir.display()
);
return Ok(());
}
let tar_file_path = spanf!(
log,
log.span("server.download"),
download_server(&paths.server_dir, release, log)
)?;
span!(
log,
log.span("server.extract"),
install_server(&tar_file_path, paths, log)
)?;
Ok(())
}
async fn download_server(
path: &Path,
release: &Release,
log: &log::Logger,
) -> Result<PathBuf, AnyError> {
let response = UpdateService::new(log.clone(), reqwest::Client::new())
.get_download_stream(release)
.await?;
let mut save_path = path.to_owned();
let fname = response
.url()
.path_segments()
.and_then(|segments| segments.last())
.and_then(|name| if name.is_empty() { None } else { Some(name) })
.unwrap_or("tmp.zip");
info!(
log,
"Downloading VS Code server {} -> {}",
response.url(),
save_path.display()
);
save_path.push(fname);
http::download_into_file(
&save_path,
log.get_download_logger("server download progress:"),
response,
)
.await?;
Ok(save_path)
}
fn install_server(
compressed_file: &Path,
paths: &ServerPaths,
log: &log::Logger,
) -> Result<(), AnyError> {
info!(log, "Setting up server...");
unzip_downloaded_release(compressed_file, &paths.server_dir, SilentCopyProgress())?;
match fs::remove_file(&compressed_file) {
Ok(()) => {}
Err(e) => {
if e.kind() != ErrorKind::NotFound {
return Err(AnyError::from(wrap(e, "error removing downloaded file")));
}
}
}
if !paths.executable.exists() {
return Err(AnyError::from(MissingEntrypointError()));
}
Ok(())
}
/// Ensures the given list of extensions are installed on the running server.
async fn do_extension_install_on_running_server(
start_script_path: &Path,
extensions: &[String],
log: &log::Logger,
) -> Result<(), AnyError> {
if extensions.is_empty() {
return Ok(());
}
debug!(log, "Installing extensions...");
let command = format!(
"{} {}",
start_script_path.display(),
extensions
.iter()
.map(|s| get_extensions_flag(s))
.collect::<Vec<String>>()
.join(" ")
);
let result = capture_command("bash", &["-c", &command]).await?;
if !result.status.success() {
Err(AnyError::from(ExtensionInstallFailed(
String::from_utf8_lossy(&result.stderr).to_string(),
)))
} else {
Ok(())
}
}
pub struct ServerBuilder<'a> {
logger: &'a log::Logger,
server_params: &'a ResolvedServerParams,
last_used: LastUsedServers<'a>,
server_paths: ServerPaths,
}
impl<'a> ServerBuilder<'a> {
pub fn new(
logger: &'a log::Logger,
server_params: &'a ResolvedServerParams,
launcher_paths: &'a LauncherPaths,
) -> Self {
Self {
logger,
server_params,
last_used: LastUsedServers::new(launcher_paths),
server_paths: server_params
.as_installed_server()
.server_paths(launcher_paths),
}
}
/// Gets any already-running server from this directory.
pub async fn get_running(&self) -> Result<Option<AnyCodeServer>, AnyError> {
info!(
self.logger,
"Checking {} and {} for a running server...",
self.server_paths.logfile.display(),
self.server_paths.pidfile.display()
);
let pid = match self.server_paths.get_running_pid() {
Some(pid) => pid,
None => return Ok(None),
};
info!(self.logger, "Found running server (pid={})", pid);
if !Path::new(&self.server_paths.logfile).exists() {
warning!(self.logger, "VS Code Server is running but its logfile is missing. Don't delete the VS Code Server manually, run the command 'code-server prune'.");
return Ok(None);
}
do_extension_install_on_running_server(
&self.server_paths.executable,
&self.server_params.code_server_args.install_extensions,
self.logger,
)
.await?;
let origin = CodeServerOrigin::Existing(pid);
let contents = fs::read_to_string(&self.server_paths.logfile)
.expect("Something went wrong reading log file");
if let Some(port) = parse_port_from(&contents) {
Ok(Some(AnyCodeServer::Port(PortCodeServer {
commit_id: self.server_params.release.commit.to_owned(),
port,
origin,
})))
} else if let Some(socket) = parse_socket_from(&contents) {
Ok(Some(AnyCodeServer::Socket(SocketCodeServer {
commit_id: self.server_params.release.commit.to_owned(),
socket,
origin,
})))
} else {
Ok(None)
}
}
/// Ensures the server is set up in the configured directory.
pub async fn setup(&self) -> Result<(), AnyError> {
debug!(self.logger, "Installing and setting up VS Code Server...");
check_and_create_dir(&self.server_paths.server_dir).await?;
install_server_if_needed(self.logger, &self.server_paths, &self.server_params.release)
.await?;
debug!(self.logger, "Server setup complete");
match self.last_used.add(self.server_params.as_installed_server()) {
Err(e) => warning!(self.logger, "Error adding server to last used: {}", e),
Ok(count) if count > MAX_RETAINED_SERVERS => {
if let Err(e) = self.last_used.trim(self.logger, MAX_RETAINED_SERVERS) {
warning!(self.logger, "Error trimming old servers: {}", e);
}
}
Ok(_) => {}
}
Ok(())
}
pub async fn listen_on_default_socket(&self) -> Result<SocketCodeServer, AnyError> {
let requested_file = if cfg!(target_os = "windows") {
PathBuf::from(format!(r"\\.\pipe\vscode-server-{}", Uuid::new_v4()))
} else {
std::env::temp_dir().join(format!("vscode-server-{}", Uuid::new_v4()))
};
self.listen_on_socket(&requested_file).await
}
pub async fn listen_on_socket(&self, socket: &Path) -> Result<SocketCodeServer, AnyError> {
Ok(spanf!(
self.logger,
self.logger.span("server.start").with_attributes(vec! {
KeyValue::new("commit_id", self.server_params.release.commit.to_string()),
KeyValue::new("quality", format!("{}", self.server_params.release.quality)),
}),
self._listen_on_socket(socket)
)?)
}
async fn _listen_on_socket(&self, socket: &Path) -> Result<SocketCodeServer, AnyError> {
remove_file(&socket).await.ok(); // ignore any error if it doesn't exist
let mut cmd = self.get_base_command();
cmd.arg("--start-server")
.arg("--without-connection-token")
.arg("--enable-remote-auto-shutdown")
.arg(format!("--socket-path={}", socket.display()));
let child = self.spawn_server_process(cmd)?;
let log_file = self.get_logfile()?;
let plog = self.logger.prefixed(&log::new_code_server_prefix());
let (mut origin, listen_rx) =
monitor_server::<SocketMatcher, PathBuf>(child, Some(log_file), plog, false);
let socket = match timeout(Duration::from_secs(8), listen_rx).await {
Err(e) => {
origin.kill().await;
Err(wrap(e, "timed out looking for socket"))
}
Ok(Err(e)) => {
origin.kill().await;
Err(wrap(e, "server exited without writing socket"))
}
Ok(Ok(socket)) => Ok(socket),
}?;
info!(self.logger, "Server started");
Ok(SocketCodeServer {
commit_id: self.server_params.release.commit.to_owned(),
socket,
origin,
})
}
/// Starts with a given opaque set of args. Does not set up any port or
/// socket, but does return one if present, in the form of a channel.
pub async fn start_opaque_with_args<M, R>(
&self,
args: &[String],
) -> Result<(CodeServerOrigin, Receiver<R>), AnyError>
where
M: ServerOutputMatcher<R>,
R: 'static + Send + std::fmt::Debug,
{
let mut cmd = self.get_base_command();
cmd.args(args);
let child = self.spawn_server_process(cmd)?;
let plog = self.logger.prefixed(&log::new_code_server_prefix());
Ok(monitor_server::<M, R>(child, None, plog, true))
}
fn spawn_server_process(&self, mut cmd: Command) -> Result<Child, AnyError> {
info!(self.logger, "Starting server...");
debug!(self.logger, "Starting server with command... {:?}", cmd);
let child = cmd
.stderr(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.spawn()
.map_err(|e| wrap(e, "error spawning server"))?;
self.server_paths
.write_pid(child.id().expect("expected server to have pid"))?;
Ok(child)
}
fn get_logfile(&self) -> Result<File, WrappedError> {
File::create(&self.server_paths.logfile).map_err(|e| {
wrap(
e,
format!(
"error creating log file {}",
self.server_paths.logfile.display()
),
)
})
}
fn get_base_command(&self) -> Command {
let mut cmd = Command::new(&self.server_paths.executable);
cmd.stdin(std::process::Stdio::null())
.args(self.server_params.code_server_args.command_arguments());
cmd
}
}
fn monitor_server<M, R>(
mut child: Child,
log_file: Option<File>,
plog: log::Logger,
write_directly: bool,
) -> (CodeServerOrigin, Receiver<R>)
where
M: ServerOutputMatcher<R>,
R: 'static + Send + std::fmt::Debug,
{
let stdout = child
.stdout
.take()
.expect("child did not have a handle to stdout");
let stderr = child
.stderr
.take()
.expect("child did not have a handle to stdout");
let (listen_tx, listen_rx) = tokio::sync::oneshot::channel();
// Handle stderr and stdout in a separate task. Initially scan lines looking
// for the listening port. Afterwards, just scan and write out to the file.
tokio::spawn(async move {
let mut stdout_reader = BufReader::new(stdout).lines();
let mut stderr_reader = BufReader::new(stderr).lines();
let write_line = |line: &str| -> std::io::Result<()> {
if let Some(mut f) = log_file.as_ref() {
f.write_all(line.as_bytes())?;
f.write_all(&[b'\n'])?;
}
if write_directly {
println!("{}", line);
} else {
trace!(plog, line);
}
Ok(())
};
loop {
let line = tokio::select! {
l = stderr_reader.next_line() => l,
l = stdout_reader.next_line() => l,
};
match line {
Err(e) => {
trace!(plog, "error reading from stdout/stderr: {}", e);
return;
}
Ok(None) => break,
Ok(Some(l)) => {
write_line(&l).ok();
if let Some(listen_on) = M::match_line(&l) {
trace!(plog, "parsed location: {:?}", listen_on);
listen_tx.send(listen_on).ok();
break;
}
}
}
}
loop {
let line = tokio::select! {
l = stderr_reader.next_line() => l,
l = stdout_reader.next_line() => l,
};
match line {
Err(e) => {
trace!(plog, "error reading from stdout/stderr: {}", e);
break;
}
Ok(None) => break,
Ok(Some(l)) => {
write_line(&l).ok();
}
}
}
});
let origin = CodeServerOrigin::New(child);
(origin, listen_rx)
}
fn get_extensions_flag(extension_id: &str) -> String {
format!("--install-extension={}", extension_id)
}
/// A type that can be used to scan stdout from the VS Code server. Returns
/// some other type that, in turn, is returned from starting the server.
pub trait ServerOutputMatcher<R>
where
R: Send,
{
fn match_line(line: &str) -> Option<R>;
}
/// Parses a line like "Extension host agent listening on /tmp/foo.sock"
struct SocketMatcher();
impl ServerOutputMatcher<PathBuf> for SocketMatcher {
fn match_line(line: &str) -> Option<PathBuf> {
parse_socket_from(line)
}
}
/// Parses a line like "Extension host agent listening on 9000"
pub struct PortMatcher();
impl ServerOutputMatcher<u16> for PortMatcher {
fn match_line(line: &str) -> Option<u16> {
parse_port_from(line)
}
}
/// Parses a line like "Web UI available at http://localhost:9000/?tkn=..."
pub struct WebUiMatcher();
impl ServerOutputMatcher<reqwest::Url> for WebUiMatcher {
fn match_line(line: &str) -> Option<reqwest::Url> {
WEB_UI_RE.captures(line).and_then(|cap| {
cap.get(1)
.and_then(|uri| reqwest::Url::parse(uri.as_str()).ok())
})
}
}
/// Does not do any parsing and just immediately returns an empty result.
pub struct NoOpMatcher();
impl ServerOutputMatcher<()> for NoOpMatcher {
fn match_line(_: &str) -> Option<()> {
Some(())
}
}
fn parse_socket_from(text: &str) -> Option<PathBuf> {
LISTENING_PORT_RE
.captures(text)
.and_then(|cap| cap.get(1).map(|path| PathBuf::from(path.as_str())))
}
fn parse_port_from(text: &str) -> Option<u16> {
LISTENING_PORT_RE.captures(text).and_then(|cap| {
cap.get(1)
.and_then(|path| path.as_str().parse::<u16>().ok())
})
}

View file

@ -0,0 +1,726 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::constants::{CONTROL_PORT, LAUNCHER_VERSION, PROTOCOL_VERSION};
use crate::log;
use crate::state::LauncherPaths;
use crate::update::Update;
use crate::update_service::Platform;
use crate::util::errors::{
wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError,
};
use crate::util::sync::{new_barrier, Barrier};
use opentelemetry::trace::SpanKind;
use opentelemetry::KeyValue;
use serde::Serialize;
use std::convert::Infallible;
use std::env;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::pin;
use tokio::sync::{mpsc, oneshot, Mutex};
use super::code_server::{
AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw, SocketCodeServer,
};
use super::dev_tunnels::ActiveTunnel;
use super::paths::prune_stopped_servers;
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse,
ForwardParams, ForwardResult, GetHostnameResponse, RefServerMessageParams, ResponseError,
ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod, SuccessResponse,
ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult, VersionParams,
};
use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge};
type ServerBridgeList = Option<Vec<(u16, ServerBridge)>>;
type ServerBridgeListLock = Arc<Mutex<ServerBridgeList>>;
struct HandlerContext {
/// Exit barrier for the socket.
closer: Barrier<()>,
/// Log handle for the server
log: log::Logger,
/// A loopback channel to talk to the TCP server task.
server_tx: mpsc::Sender<ServerSignal>,
/// A loopback channel to talk to the socket server task.
socket_tx: mpsc::Sender<SocketSignal>,
/// Configured launcher paths.
launcher_paths: LauncherPaths,
/// Connected VS Code Server
code_server: Option<SocketCodeServer>,
/// Potentially many "websocket" connections to client
server_bridges: ServerBridgeListLock,
// the cli arguments used to start the code server
code_server_args: CodeServerArgs,
/// counter for the number of bytes received from the socket
rx_counter: Arc<AtomicUsize>,
/// port forwarding functionality
port_forwarding: PortForwarding,
/// install platform for the VS Code server
platform: Platform,
}
impl HandlerContext {
async fn dispose(self) {
let bridges: ServerBridgeList = {
let mut lock = self.server_bridges.lock().await;
let bridges = lock.take();
*lock = None;
bridges
};
if let Some(b) = bridges {
for (_, bridge) in b {
if let Err(e) = bridge.close().await {
warning!(
self.log,
"Could not properly dispose of connection context: {}",
e
)
} else {
debug!(self.log, "Closed server bridge.");
}
}
}
info!(self.log, "Disposed of connection to running server.");
}
}
enum ServerSignal {
/// Signalled when the server has been updated and we want to respawn.
/// We'd generally need to stop and then restart the launcher, but the
/// program might be managed by a supervisor like systemd. Instead, we
/// will stop the TCP listener and spawn the launcher again as a subprocess
/// with the same arguments we used.
Respawn,
}
struct CloseReason(String);
enum SocketSignal {
/// Signals bytes to send to the socket.
Send(Vec<u8>),
/// Closes the socket (e.g. as a result of an error)
CloseWith(CloseReason),
/// Disposes ServerBridge corresponding to an ID
CloseServerBridge(u16),
}
impl SocketSignal {
fn from_message<T>(msg: &T) -> Self
where
T: Serialize + ?Sized,
{
SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
}
}
impl FromServerMessage for SocketSignal {
fn from_server_message(i: u16, body: &[u8]) -> Self {
SocketSignal::from_message(&ToClientRequest {
id: None,
params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }),
})
}
fn from_closed_server_bridge(i: u16) -> Self {
SocketSignal::CloseServerBridge(i)
}
}
pub struct ServerTermination {
/// Whether the server should be respawned in a new binary (see ServerSignal.Respawn).
pub respawn: bool,
pub tunnel: ActiveTunnel,
}
fn print_listening(log: &log::Logger, tunnel_name: &str) {
debug!(log, "VS Code Server is listening for incoming connections");
let extension_name = "+ms-vscode.remote-server";
let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from(""));
let current_dir = env::current_dir().unwrap_or_else(|_| PathBuf::from(""));
let dir = if home_dir == current_dir {
PathBuf::from("")
} else {
current_dir
};
let mut addr = url::Url::parse("https://insiders.vscode.dev").unwrap();
{
let mut ps = addr.path_segments_mut().unwrap();
ps.push(extension_name);
ps.push(tunnel_name);
for segment in &dir {
let as_str = segment.to_string_lossy();
if !(as_str.len() == 1 && as_str.starts_with(std::path::MAIN_SEPARATOR)) {
ps.push(as_str.as_ref());
}
}
}
let message = &format!("\nOpen this link in your browser {}\n", addr);
log.result(message);
}
// Runs the launcher server. Exits on a ctrl+c or when requested by a user.
// Note that client connections may not be closed when this returns; use
// `close_all_clients()` on the ServerTermination to make this happen.
pub async fn serve(
log: &log::Logger,
mut tunnel: ActiveTunnel,
launcher_paths: &LauncherPaths,
code_server_args: &CodeServerArgs,
platform: Platform,
shutdown_rx: oneshot::Receiver<()>,
) -> Result<ServerTermination, AnyError> {
let mut port = tunnel.add_port_direct(CONTROL_PORT).await?;
print_listening(log, &tunnel.name);
let mut forwarding = PortForwardingProcessor::new();
let (tx, mut rx) = mpsc::channel::<ServerSignal>(4);
let (exit_barrier, signal_exit) = new_barrier();
pin!(shutdown_rx);
loop {
tokio::select! {
_ = &mut shutdown_rx => {
info!(log, "Received interrupt, shutting down...");
drop(signal_exit);
return Ok(ServerTermination {
respawn: false,
tunnel,
});
},
c = rx.recv() => {
if let Some(ServerSignal::Respawn) = c {
drop(signal_exit);
return Ok(ServerTermination {
respawn: true,
tunnel,
});
}
},
Some(w) = forwarding.recv() => {
forwarding.process(w, &mut tunnel).await;
},
l = port.recv() => {
let socket = match l {
Some(p) => p,
None => {
warning!(log, "ssh tunnel disposed, tearing down");
return Ok(ServerTermination {
respawn: false,
tunnel,
});
}
};
let own_log = log.prefixed(&log::new_rpc_prefix());
let own_tx = tx.clone();
let own_paths = launcher_paths.clone();
let own_exit = exit_barrier.clone();
let own_code_server_args = code_server_args.clone();
let own_forwarding = forwarding.handle();
tokio::spawn(async move {
use opentelemetry::trace::{FutureExt, TraceContextExt};
let span = own_log.span("server.socket").with_kind(SpanKind::Consumer).start(own_log.tracer());
let cx = opentelemetry::Context::current_with_span(span);
let serve_at = Instant::now();
debug!(own_log, "Serving new connection");
let (writehalf, readhalf) = socket.into_split();
let stats = process_socket(own_exit, readhalf, writehalf, own_log, own_tx, own_paths, own_code_server_args, own_forwarding, platform).with_context(cx.clone()).await;
cx.span().add_event(
"socket.bandwidth",
vec![
KeyValue::new("tx", stats.tx as f64),
KeyValue::new("rx", stats.rx as f64),
KeyValue::new("duration_ms", serve_at.elapsed().as_millis() as f64),
],
);
cx.span().end();
});
}
}
}
}
struct SocketStats {
rx: usize,
tx: usize,
}
#[allow(clippy::too_many_arguments)] // necessary here
async fn process_socket(
mut exit_barrier: Barrier<()>,
readhalf: impl AsyncRead + Send + Unpin + 'static,
mut writehalf: impl AsyncWrite + Unpin,
log: log::Logger,
server_tx: mpsc::Sender<ServerSignal>,
launcher_paths: LauncherPaths,
code_server_args: CodeServerArgs,
port_forwarding: PortForwarding,
platform: Platform,
) -> SocketStats {
let (socket_tx, mut socket_rx) = mpsc::channel(4);
let rx_counter = Arc::new(AtomicUsize::new(0));
let server_bridges: ServerBridgeListLock = Arc::new(Mutex::new(Some(vec![])));
let server_bridges_lock = Arc::clone(&server_bridges);
let barrier_ctx = exit_barrier.clone();
let log_ctx = log.clone();
let rx_counter_ctx = rx_counter.clone();
tokio::spawn(async move {
let mut ctx = HandlerContext {
closer: barrier_ctx,
server_tx,
socket_tx,
log: log_ctx,
launcher_paths,
code_server_args,
rx_counter: rx_counter_ctx,
code_server: None,
server_bridges: server_bridges_lock,
port_forwarding,
platform,
};
send_version(&ctx.socket_tx).await;
if let Err(e) = handle_socket_read(readhalf, &mut ctx).await {
debug!(ctx.log, "closing socket reader: {}", e);
ctx.socket_tx
.send(SocketSignal::CloseWith(CloseReason(format!("{}", e))))
.await
.ok();
}
ctx.dispose().await;
});
let mut tx_counter = 0;
loop {
tokio::select! {
_ = exit_barrier.wait() => {
writehalf.shutdown().await.ok();
break;
},
recv = socket_rx.recv() => match recv {
None => break,
Some(message) => match message {
SocketSignal::Send(bytes) => {
tx_counter += bytes.len();
if let Err(e) = writehalf.write_all(&bytes).await {
debug!(log, "Closing connection: {}", e);
break;
}
}
SocketSignal::CloseWith(reason) => {
debug!(log, "Closing connection: {}", reason.0);
break;
}
SocketSignal::CloseServerBridge(id) => {
let mut lock = server_bridges.lock().await;
match &mut *lock {
Some(bridges) => {
if let Some(index) = bridges.iter().position(|(i, _)| *i == id) {
(*bridges).remove(index as usize);
}
},
None => {}
}
}
}
}
}
}
SocketStats {
tx: tx_counter,
rx: rx_counter.load(Ordering::Acquire),
}
}
async fn send_version(tx: &mpsc::Sender<SocketSignal>) {
tx.send(SocketSignal::from_message(&ToClientRequest {
id: None,
params: ClientRequestMethod::version(VersionParams {
version: LAUNCHER_VERSION.unwrap_or("dev"),
protocol_version: PROTOCOL_VERSION,
}),
}))
.await
.ok();
}
async fn handle_socket_read(
readhalf: impl AsyncRead + Unpin,
ctx: &mut HandlerContext,
) -> Result<(), std::io::Error> {
let mut socket_reader = BufReader::new(readhalf);
let mut decode_buf = vec![];
let mut did_update = false;
let result = loop {
match read_next(&mut socket_reader, ctx, &mut decode_buf, &mut did_update).await {
Ok(false) => break Ok(()),
Ok(true) => { /* continue */ }
Err(e) => break Err(e),
}
};
// The connection is now closed, asked to respawn if needed
if did_update {
ctx.server_tx.send(ServerSignal::Respawn).await.ok();
}
result
}
/// Reads and handles the next data packet, returns true if the read loop should continue.
async fn read_next(
socket_reader: &mut BufReader<impl AsyncRead + Unpin>,
ctx: &mut HandlerContext,
decode_buf: &mut Vec<u8>,
did_update: &mut bool,
) -> Result<bool, std::io::Error> {
let msg_length = tokio::select! {
u = socket_reader.read_u32() => u? as usize,
_ = ctx.closer.wait() => return Ok(false),
};
decode_buf.resize(msg_length, 0);
ctx.rx_counter
.fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed);
tokio::select! {
r = socket_reader.read_exact(decode_buf) => r?,
_ = ctx.closer.wait() => return Ok(false),
};
let req = match rmp_serde::from_slice::<ToServerRequest>(decode_buf) {
Ok(req) => req,
Err(e) => {
warning!(ctx.log, "Error decoding message: {}", e);
return Ok(true); // not fatal
}
};
let log = ctx.log.prefixed(
req.id
.map(|id| format!("[call.{}]", id))
.as_deref()
.unwrap_or("notify"),
);
macro_rules! success {
($r:expr) => {
req.id
.map(|id| rmp_serde::to_vec_named(&SuccessResponse { id, result: &$r }))
};
}
macro_rules! tj {
($name:expr, $e:expr) => {
match (spanf!(
log,
log.span(&format!("call.{}", $name))
.with_kind(opentelemetry::trace::SpanKind::Server),
$e
)) {
Ok(r) => success!(r),
Err(e) => {
warning!(log, "error handling call: {:?}", e);
req.id.map(|id| {
rmp_serde::to_vec_named(&ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{:?}", e),
},
})
})
}
}
};
}
let response = match req.params {
ServerRequestMethod::ping(_) => success!(EmptyResult {}),
ServerRequestMethod::serve(p) => tj!("serve", handle_serve(ctx, &log, p)),
ServerRequestMethod::prune => tj!("prune", handle_prune(ctx)),
ServerRequestMethod::gethostname(_) => tj!("gethostname", handle_get_hostname()),
ServerRequestMethod::update(p) => tj!("update", async {
let r = handle_update(ctx, &p).await;
if matches!(&r, Ok(u) if u.did_update) {
*did_update = true;
}
r
}),
ServerRequestMethod::servermsg(m) => {
if let Err(e) = handle_server_message(ctx, m).await {
warning!(log, "error handling call: {:?}", e);
}
None
}
ServerRequestMethod::callserverhttp(p) => {
tj!("callserverhttp", handle_call_server_http(ctx, p))
}
ServerRequestMethod::forward(p) => tj!("forward", handle_forward(ctx, p)),
ServerRequestMethod::unforward(p) => tj!("unforward", handle_unforward(ctx, p)),
};
if let Some(Ok(res)) = response {
if ctx.socket_tx.send(SocketSignal::Send(res)).await.is_err() {
return Ok(false);
}
}
Ok(true)
}
#[derive(Clone)]
struct ServerOutputSink {
tx: mpsc::Sender<SocketSignal>,
}
impl log::LogSink for ServerOutputSink {
fn write_log(&self, level: log::Level, _prefix: &str, message: &str) {
let s = SocketSignal::from_message(&ToClientRequest {
id: None,
params: ClientRequestMethod::serverlog(ServerLog {
line: message,
level: level.to_u8(),
}),
});
self.tx.try_send(s).ok();
}
fn write_result(&self, _message: &str) {}
}
async fn handle_serve(
ctx: &mut HandlerContext,
log: &log::Logger,
params: ServeParams,
) -> Result<EmptyResult, AnyError> {
let mut code_server_args = ctx.code_server_args.clone();
// fill params.extensions into code_server_args.install_extensions
code_server_args
.install_extensions
.extend(params.extensions.into_iter());
let resolved = ServerParamsRaw {
commit_id: params.commit_id,
quality: params.quality,
code_server_args,
headless: true,
platform: ctx.platform,
}
.resolve(log)
.await?;
if ctx.code_server.is_none() {
let install_log = log.tee(ServerOutputSink {
tx: ctx.socket_tx.clone(),
});
let sb = ServerBuilder::new(&install_log, &resolved, &ctx.launcher_paths);
let server = match sb.get_running().await? {
Some(AnyCodeServer::Socket(s)) => s,
Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())),
None => {
sb.setup().await?;
sb.listen_on_default_socket().await?
}
};
ctx.code_server = Some(server);
}
attach_server_bridge(ctx, params.socket_id).await?;
Ok(EmptyResult {})
}
async fn attach_server_bridge(ctx: &mut HandlerContext, socket_id: u16) -> Result<u16, AnyError> {
let attached_fut = ServerBridge::new(
&ctx.code_server.as_ref().unwrap().socket,
socket_id,
&ctx.socket_tx,
)
.await;
match attached_fut {
Ok(a) => {
let mut lock = ctx.server_bridges.lock().await;
match &mut *lock {
Some(server_bridges) => (*server_bridges).push((socket_id, a)),
None => *lock = Some(vec![(socket_id, a)]),
}
trace!(ctx.log, "Attached to server");
Ok(socket_id)
}
Err(e) => Err(e),
}
}
async fn handle_server_message(
ctx: &mut HandlerContext,
params: ServerMessageParams,
) -> Result<EmptyResult, AnyError> {
let mut lock = ctx.server_bridges.lock().await;
match &mut *lock {
Some(server_bridges) => {
let matched_bridge = server_bridges.iter_mut().find(|(id, _)| *id == params.i);
match matched_bridge {
Some((_, sb)) => sb
.write(params.body)
.await
.map_err(|_| AnyError::from(ServerWriteError()))?,
None => return Err(AnyError::from(NoAttachedServerError())),
}
}
None => return Err(AnyError::from(NoAttachedServerError())),
}
Ok(EmptyResult {})
}
async fn handle_prune(ctx: &HandlerContext) -> Result<Vec<String>, AnyError> {
prune_stopped_servers(&ctx.launcher_paths).map(|v| {
v.iter()
.map(|p| p.server_dir.display().to_string())
.collect()
})
}
async fn handle_update(
ctx: &HandlerContext,
params: &UpdateParams,
) -> Result<UpdateResult, AnyError> {
let updater = Update::new();
let latest_release = updater.get_latest_release().await?;
let up_to_date = match LAUNCHER_VERSION {
Some(v) => v == latest_release.version,
None => true,
};
if !params.do_update || up_to_date {
return Ok(UpdateResult {
up_to_date,
did_update: false,
});
}
info!(ctx.log, "Updating CLI from {}", latest_release.version);
let current_exe = std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?;
updater
.switch_to_release(&latest_release, &current_exe)
.await?;
Ok(UpdateResult {
up_to_date: true,
did_update: true,
})
}
async fn handle_get_hostname() -> Result<GetHostnameResponse, Infallible> {
Ok(GetHostnameResponse {
value: gethostname::gethostname().to_string_lossy().into_owned(),
})
}
async fn handle_forward(
ctx: &HandlerContext,
params: ForwardParams,
) -> Result<ForwardResult, AnyError> {
info!(ctx.log, "Forwarding port {}", params.port);
let uri = ctx.port_forwarding.forward(params.port).await?;
Ok(ForwardResult { uri })
}
async fn handle_unforward(
ctx: &HandlerContext,
params: UnforwardParams,
) -> Result<EmptyResult, AnyError> {
info!(ctx.log, "Unforwarding port {}", params.port);
ctx.port_forwarding.unforward(params.port).await?;
Ok(EmptyResult {})
}
async fn handle_call_server_http(
ctx: &HandlerContext,
params: CallServerHttpParams,
) -> Result<CallServerHttpResult, AnyError> {
use hyper::{body, client::conn::Builder, Body, Request};
// We use Hyper directly here since reqwest doesn't support sockets/pipes.
// See https://github.com/seanmonstar/reqwest/issues/39
let socket = match &ctx.code_server {
Some(cs) => &cs.socket,
None => return Err(AnyError::from(NoAttachedServerError())),
};
let rw = get_socket_rw_stream(socket).await?;
let (mut request_sender, connection) = Builder::new()
.handshake(rw)
.await
.map_err(|e| wrap(e, "error establishing connection"))?;
// start the connection processing; it's shut down when the sender is dropped
tokio::spawn(connection);
let mut request_builder = Request::builder()
.method::<&str>(params.method.as_ref())
.uri(format!("http://127.0.0.1{}", params.path))
.header("Host", "127.0.0.1");
for (k, v) in params.headers {
request_builder = request_builder.header(k, v);
}
let request = request_builder
.body(Body::from(params.body.unwrap_or_default()))
.map_err(|e| wrap(e, "invalid request"))?;
let response = request_sender
.send_request(request)
.await
.map_err(|e| wrap(e, "error sending request"))?;
Ok(CallServerHttpResult {
status: response.status().as_u16(),
headers: response
.headers()
.into_iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect(),
body: body::to_bytes(response)
.await
.map_err(|e| wrap(e, "error reading response body"))?
.to_vec(),
})
}

View file

@ -0,0 +1,822 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::auth;
use crate::constants::{CONTROL_PORT, TUNNEL_SERVICE_USER_AGENT};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{
wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError,
};
use crate::util::input::prompt_placeholder;
use crate::{debug, info, log, spanf, trace, warning};
use async_trait::async_trait;
use futures::TryFutureExt;
use rand::prelude::IteratorRandom;
use regex::Regex;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tunnels::connections::{ForwardedPortConnection, HostRelay};
use tunnels::contracts::{
Tunnel, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN, TUNNEL_PROTOCOL_AUTO,
};
use tunnels::management::{
new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
NO_REQUEST_OPTIONS,
};
use super::name_generator;
#[derive(Clone, Serialize, Deserialize)]
pub struct PersistedTunnel {
pub name: String,
pub id: String,
pub cluster: String,
}
impl PersistedTunnel {
pub fn into_locator(self) -> TunnelLocator {
TunnelLocator::ID {
cluster: self.cluster,
id: self.id,
}
}
pub fn locator(&self) -> TunnelLocator {
TunnelLocator::ID {
cluster: self.cluster.clone(),
id: self.id.clone(),
}
}
}
#[async_trait]
trait AccessTokenProvider: Send + Sync {
/// Gets the current access token.
async fn refresh_token(&self) -> Result<String, WrappedError>;
}
/// Access token provider that provides a fixed token without refreshing.
struct StaticAccessTokenProvider(String);
impl StaticAccessTokenProvider {
pub fn new(token: String) -> Self {
Self(token)
}
}
#[async_trait]
impl AccessTokenProvider for StaticAccessTokenProvider {
async fn refresh_token(&self) -> Result<String, WrappedError> {
Ok(self.0.clone())
}
}
/// Access token provider that looks up the token from the tunnels API.
struct LookupAccessTokenProvider {
client: TunnelManagementClient,
locator: TunnelLocator,
log: log::Logger,
initial_token: Arc<Mutex<Option<String>>>,
}
impl LookupAccessTokenProvider {
pub fn new(
client: TunnelManagementClient,
locator: TunnelLocator,
log: log::Logger,
initial_token: Option<String>,
) -> Self {
Self {
client,
locator,
log,
initial_token: Arc::new(Mutex::new(initial_token)),
}
}
}
#[async_trait]
impl AccessTokenProvider for LookupAccessTokenProvider {
async fn refresh_token(&self) -> Result<String, WrappedError> {
if let Some(token) = self.initial_token.lock().unwrap().take() {
return Ok(token);
}
let tunnel_lookup = spanf!(
self.log,
self.log.span("dev-tunnel.tag.get"),
self.client.get_tunnel(
&self.locator,
&TunnelRequestOptions {
token_scopes: vec!["host".to_string()],
..Default::default()
}
)
);
trace!(self.log, "Successfully refreshed access token");
match tunnel_lookup {
Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
Err(e) => Err(wrap(e, "failed to lookup tunnel")),
}
}
}
#[derive(Clone)]
pub struct DevTunnels {
log: log::Logger,
launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
client: TunnelManagementClient,
}
/// Representation of a tunnel returned from the `start` methods.
pub struct ActiveTunnel {
/// Name of the tunnel
pub name: String,
manager: ActiveTunnelManager,
}
impl ActiveTunnel {
/// Closes and unregisters the tunnel.
pub async fn close(&mut self) -> Result<(), AnyError> {
self.manager.kill().await?;
Ok(())
}
/// Forwards a port to local connections.
pub async fn add_port_direct(
&mut self,
port_number: u16,
) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
let port = self.manager.add_port_direct(port_number).await?;
Ok(port)
}
/// Forwards a port over TCP.
pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> {
self.manager.add_port_tcp(port_number).await?;
Ok(())
}
/// Removes a forwarded port TCP.
pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> {
self.manager.remove_port(port_number).await?;
Ok(())
}
/// Gets the public URI on which a forwarded port can be access in browser.
pub async fn get_port_uri(&mut self, port: u16) -> Result<String, AnyError> {
let endpoint = self.manager.get_endpoint().await?;
let format = endpoint
.base
.port_uri_format
.expect("expected to have port format");
Ok(format.replace(PORT_TOKEN, &port.to_string()))
}
}
const LAUNCHER_TUNNEL_TAG: &str = "vscode-server-launcher";
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
tunnel
.access_tokens
.as_ref()
.expect("expected to have access tokens")
.get("host")
.expect("expected to have host token")
.to_string()
}
fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
if name.len() > MAX_TUNNEL_NAME_LENGTH {
return Err(InvalidTunnelName(format!(
"Names cannot be longer than {} characters. Please try a different name.",
MAX_TUNNEL_NAME_LENGTH
)));
}
let re = Regex::new(r"^([\w-]+)$").unwrap();
if !re.is_match(name) {
return Err(InvalidTunnelName(
"Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
));
}
Ok(())
}
/// Structure optionally passed into `start_existing_tunnel` to forward an existing tunnel.
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
/// Name you'd like to assign preexisting tunnel to use to connect to the VS Code Server
pub tunnel_name: String,
/// Token to authenticate and use preexisting tunnel
pub host_token: String,
/// Id of preexisting tunnel to use to connect to the VS Code Server
pub tunnel_id: String,
/// Cluster of preexisting tunnel to use to connect to the VS Code Server
pub cluster: String,
}
impl DevTunnels {
pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth);
DevTunnels {
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
}
}
pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
let tunnel = match self.launcher_tunnel.load() {
Some(t) => t,
None => {
return Ok(());
}
};
spanf!(
self.log,
self.log.span("dev-tunnel.delete"),
self.client
.delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
self.launcher_tunnel.save(None)?;
Ok(())
}
pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
is_valid_name(name)?;
let existing = spanf!(
self.log,
self.log.span("dev-tunnel.rename.search"),
self.client.list_all_tunnels(&TunnelRequestOptions {
tags: vec![LAUNCHER_TUNNEL_TAG.to_string(), name.to_string()],
require_all_tags: true,
..Default::default()
})
)
.map_err(|e| wrap(e, "failed to list existing tunnels"))?;
if !existing.is_empty() {
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
"tunnel name already in use".to_string(),
)));
}
let mut tunnel = match self.launcher_tunnel.load() {
Some(t) => t,
None => {
debug!(self.log, "No code server tunnel found, creating new one");
let (persisted, _) = self.create_tunnel(name).await?;
self.launcher_tunnel.save(Some(persisted))?;
return Ok(());
}
};
let locator = tunnel.locator();
let mut full_tunnel = spanf!(
self.log,
self.log.span("dev-tunnel.tag.get"),
self.client.get_tunnel(&locator, NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to lookup tunnel"))?;
full_tunnel.tags = vec![name.to_string(), LAUNCHER_TUNNEL_TAG.to_string()];
spanf!(
self.log,
self.log.span("dev-tunnel.tag.update"),
self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to update tunnel tags"))?;
tunnel.name = name.to_string();
self.launcher_tunnel.save(Some(tunnel.clone()))?;
Ok(())
}
/// Starts a new tunnel for the code server on the port. Unlike `start_new_tunnel`,
/// this attempts to reuse or generate a friendly tunnel name.
pub async fn start_new_launcher_tunnel(
&mut self,
use_random_name: bool,
) -> Result<ActiveTunnel, AnyError> {
let (tunnel, persisted) = match self.launcher_tunnel.load() {
Some(persisted) => {
let tunnel_lookup = spanf!(
self.log,
self.log.span("dev-tunnel.tag.get"),
self.client.get_tunnel(
&persisted.locator(),
&TunnelRequestOptions {
include_ports: true,
token_scopes: vec!["host".to_string()],
..Default::default()
}
)
);
match tunnel_lookup {
Ok(ft) => (ft, persisted),
Err(HttpError::ResponseError(e))
if e.status_code == StatusCode::NOT_FOUND
|| e.status_code == StatusCode::FORBIDDEN =>
{
let (persisted, tunnel) = self.create_tunnel(&persisted.name).await?;
self.launcher_tunnel.save(Some(persisted.clone()))?;
(tunnel, persisted)
}
Err(e) => return Err(AnyError::from(wrap(e, "failed to lookup tunnel"))),
}
}
None => {
debug!(self.log, "No code server tunnel found, creating new one");
let name = self.get_name_for_tunnel(use_random_name).await?;
let (persisted, full_tunnel) = self.create_tunnel(&name).await?;
self.launcher_tunnel.save(Some(persisted.clone()))?;
(full_tunnel, persisted)
}
};
let locator = TunnelLocator::try_from(&tunnel).unwrap();
let host_token = get_host_token_from_tunnel(&tunnel);
for port_to_delete in tunnel
.ports
.iter()
.filter(|p| p.port_number != CONTROL_PORT)
{
let output_fut = self.client.delete_tunnel_port(
&locator,
port_to_delete.port_number,
NO_REQUEST_OPTIONS,
);
spanf!(
self.log,
self.log.span("dev-tunnel.port.delete"),
output_fut
)
.map_err(|e| wrap(e, "failed to delete port"))?;
}
// cleanup any old trailing tunnel endpoints
for endpoint in tunnel.endpoints {
let fut = self.client.delete_tunnel_endpoints(
&locator,
&endpoint.host_id,
None,
NO_REQUEST_OPTIONS,
);
spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
.map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
}
self.start_tunnel(
locator.clone(),
&persisted,
self.client.clone(),
LookupAccessTokenProvider::new(
self.client.clone(),
locator,
self.log.clone(),
Some(host_token),
),
)
.await
}
async fn create_tunnel(&mut self, name: &str) -> Result<(PersistedTunnel, Tunnel), AnyError> {
info!(self.log, "Creating tunnel with the name: {}", name);
let mut tried_recycle = false;
let new_tunnel = Tunnel {
tags: vec![name.to_string(), LAUNCHER_TUNNEL_TAG.to_string()],
..Default::default()
};
loop {
let result = spanf!(
self.log,
self.log.span("dev-tunnel.create"),
self.client.create_tunnel(&new_tunnel, NO_REQUEST_OPTIONS)
);
match result {
Err(HttpError::ResponseError(e))
if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
{
if !tried_recycle && self.try_recycle_tunnel().await? {
tried_recycle = true;
continue;
}
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
"You've exceeded the 10 machine limit for the port fowarding service. Please remove other machines before trying to add this machine.".to_string(),
)));
}
Err(e) => {
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
format!("{:?}", e),
)))
}
Ok(t) => {
return Ok((
PersistedTunnel {
cluster: t.cluster_id.clone().unwrap(),
id: t.tunnel_id.clone().unwrap(),
name: name.to_string(),
},
t,
))
}
}
}
}
/// Tries to delete an unused tunnel, and then creates a tunnel with the
/// given `new_name`.
async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
trace!(
self.log,
"Tunnel limit hit, trying to recycle an old tunnel"
);
let existing_tunnels = self.list_all_server_tunnels().await?;
let recyclable = existing_tunnels
.iter()
.filter(|t| {
t.status
.as_ref()
.and_then(|s| s.host_connection_count.as_ref())
.map(|c| c.get_count())
.unwrap_or(0)
== 0
})
.choose(&mut rand::thread_rng());
match recyclable {
Some(tunnel) => {
trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
spanf!(
self.log,
self.log.span("dev-tunnel.delete"),
self.client
.delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
Ok(true)
}
None => {
trace!(self.log, "No tunnels available to recycle");
Ok(false)
}
}
}
async fn list_all_server_tunnels(&mut self) -> Result<Vec<Tunnel>, AnyError> {
let tunnels = spanf!(
self.log,
self.log.span("dev-tunnel.listall"),
self.client.list_all_tunnels(&TunnelRequestOptions {
tags: vec![LAUNCHER_TUNNEL_TAG.to_string()],
require_all_tags: true,
..Default::default()
})
)
.map_err(|e| wrap(e, "error listing current tunnels"))?;
Ok(tunnels)
}
async fn get_name_for_tunnel(&mut self, use_random_name: bool) -> Result<String, AnyError> {
let mut placeholder_name = name_generator::generate_name(MAX_TUNNEL_NAME_LENGTH);
let existing_tunnels = self.list_all_server_tunnels().await?;
let is_name_free = |n: &str| {
!existing_tunnels
.iter()
.any(|v| v.tags.iter().any(|t| t == n))
};
if use_random_name {
while !is_name_free(&placeholder_name) {
placeholder_name = name_generator::generate_name(MAX_TUNNEL_NAME_LENGTH);
}
return Ok(placeholder_name);
}
loop {
let name = prompt_placeholder(
"What would you like to call this machine?",
&placeholder_name,
)?;
if let Err(e) = is_valid_name(&name) {
info!(self.log, "{}", e);
continue;
}
if is_name_free(&name) {
return Ok(name);
}
info!(self.log, "The name {} is already in use", name);
}
}
/// Hosts an existing tunnel, where the tunnel ID and host token are given.
pub async fn start_existing_tunnel(
&mut self,
tunnel: ExistingTunnel,
) -> Result<ActiveTunnel, AnyError> {
let tunnel_details = PersistedTunnel {
name: tunnel.tunnel_name,
id: tunnel.tunnel_id,
cluster: tunnel.cluster,
};
let mut mgmt = self.client.build();
mgmt.authorization(tunnels::management::Authorization::Tunnel(
tunnel.host_token.clone(),
));
self.start_tunnel(
tunnel_details.locator(),
&tunnel_details,
mgmt.into(),
StaticAccessTokenProvider::new(tunnel.host_token),
)
.await
}
async fn start_tunnel(
&mut self,
locator: TunnelLocator,
tunnel_details: &PersistedTunnel,
client: TunnelManagementClient,
access_token: impl AccessTokenProvider + 'static,
) -> Result<ActiveTunnel, AnyError> {
let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);
let endpoint_result = spanf!(
self.log,
self.log.span("dev-tunnel.serve.callback"),
manager.get_endpoint()
);
let endpoint = match endpoint_result {
Ok(endpoint) => endpoint,
Err(e) => {
error!(self.log, "Error connecting to tunnel endpoint: {}", e);
manager.kill().await.ok();
return Err(e);
}
};
debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);
Ok(ActiveTunnel {
name: tunnel_details.name.clone(),
manager,
})
}
}
struct ActiveTunnelManager {
close_tx: Option<mpsc::Sender<()>>,
endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
relay: Arc<tokio::sync::Mutex<HostRelay>>,
}
impl ActiveTunnelManager {
pub fn new(
log: log::Logger,
mgmt: TunnelManagementClient,
locator: TunnelLocator,
access_token: impl AccessTokenProvider + 'static,
) -> ActiveTunnelManager {
let (endpoint_tx, endpoint_rx) = watch::channel(None);
let (close_tx, close_rx) = mpsc::channel(1);
let relay = Arc::new(tokio::sync::Mutex::new(HostRelay::new(locator, mgmt)));
let relay_spawned = relay.clone();
tokio::spawn(async move {
ActiveTunnelManager::spawn_tunnel(
log,
relay_spawned,
close_rx,
endpoint_tx,
access_token,
)
.await;
});
ActiveTunnelManager {
endpoint_rx,
relay,
close_tx: Some(close_tx),
}
}
/// Adds a port for TCP/IP forwarding.
#[allow(dead_code)] // todo: port forwarding
pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), WrappedError> {
self.relay
.lock()
.await
.add_port(&TunnelPort {
port_number,
protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
..Default::default()
})
.await
.map_err(|e| wrap(e, "error adding port to relay"))?;
Ok(())
}
/// Adds a port for TCP/IP forwarding.
pub async fn add_port_direct(
&self,
port_number: u16,
) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
self.relay
.lock()
.await
.add_port_raw(&TunnelPort {
port_number,
protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
..Default::default()
})
.await
.map_err(|e| wrap(e, "error adding port to relay"))
}
/// Removes a port from TCP/IP forwarding.
pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
self.relay
.lock()
.await
.remove_port(port_number)
.await
.map_err(|e| wrap(e, "error remove port from relay"))
}
/// Gets the most recent details from the tunnel process. Returns None if
/// the process exited before providing details.
pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
loop {
if let Some(details) = &*self.endpoint_rx.borrow() {
return details.clone().map_err(AnyError::from);
}
if self.endpoint_rx.changed().await.is_err() {
return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
}
}
}
/// Kills the process, and waits for it to exit.
/// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works
pub async fn kill(&mut self) -> Result<(), AnyError> {
if let Some(tx) = self.close_tx.take() {
drop(tx);
}
self.relay
.lock()
.await
.unregister()
.await
.map_err(|e| wrap(e, "error unregistering relay"))?;
while self.endpoint_rx.changed().await.is_ok() {}
Ok(())
}
async fn spawn_tunnel(
log: log::Logger,
relay: Arc<tokio::sync::Mutex<HostRelay>>,
mut close_rx: mpsc::Receiver<()>,
endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
access_token_provider: impl AccessTokenProvider + 'static,
) {
let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));
macro_rules! fail {
($e: expr, $msg: expr) => {
warning!(log, "{}: {}", $msg, $e);
endpoint_tx.send(Some(Err($e))).ok();
backoff.delay().await;
};
}
loop {
debug!(log, "Starting tunnel to server...");
let access_token = match access_token_provider.refresh_token().await {
Ok(t) => t,
Err(e) => {
fail!(e, "Error refreshing access token, will retry");
continue;
}
};
// we don't bother making a client that can refresh the token, since
// the tunnel won't be able to host as soon as the access token expires.
let handle_res = {
let mut relay = relay.lock().await;
relay
.connect(&access_token)
.await
.map_err(|e| wrap(e, "error connecting to tunnel"))
};
let mut handle = match handle_res {
Ok(handle) => handle,
Err(e) => {
fail!(e, "Error connecting to relay, will retry");
continue;
}
};
backoff.reset();
endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();
tokio::select! {
// error is mapped like this prevent it being used across an await,
// which Rust dislikes since there's a non-sendable dyn Error in there
res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
if let Err(e) = res {
fail!(e, "Tunnel exited unexpectedly, reconnecting");
} else {
warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
backoff.delay().await;
}
},
_ = close_rx.recv() => {
trace!(log, "Tunnel closing gracefully");
trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
break;
}
}
}
}
}
struct Backoff {
failures: u32,
base_duration: Duration,
max_duration: Duration,
}
impl Backoff {
pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
Self {
failures: 0,
base_duration,
max_duration,
}
}
pub async fn delay(&mut self) {
tokio::time::sleep(self.next()).await
}
pub fn next(&mut self) -> Duration {
self.failures += 1;
let duration = self
.base_duration
.checked_mul(self.failures)
.unwrap_or(self.max_duration);
std::cmp::min(duration, self.max_duration)
}
pub fn reset(&mut self) {
self.failures = 0;
}
}

View file

@ -0,0 +1,56 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{AnyError, MissingLegalConsent};
use crate::util::input::prompt_yn;
use serde::{Deserialize, Serialize};
const LICENSE_TEXT: Option<&'static str> = option_env!("LAUNCHER_REMOTE_LICENSE_TEXT");
const LICENSE_PROMPT: Option<&'static str> = option_env!("LAUNCHER_REMOTE_LICENSE_PROMPT");
#[derive(Clone, Default, Serialize, Deserialize)]
struct PersistedConsent {
pub consented: Option<bool>,
}
pub fn require_consent(paths: &LauncherPaths) -> Result<(), AnyError> {
match LICENSE_TEXT {
Some(t) => println!("{}", t),
None => return Ok(()),
}
let prompt = match LICENSE_PROMPT {
Some(p) => p,
None => return Ok(()),
};
let license: PersistedState<PersistedConsent> =
PersistedState::new(paths.root().join("license_consent.json"));
let mut save = false;
let mut load = license.load();
if !load.consented.unwrap_or(false) {
match prompt_yn(prompt) {
Ok(true) => {
save = true;
load.consented = Some(true);
}
Ok(false) => {
return Err(AnyError::from(MissingLegalConsent(
"Sorry you cannot use VS Code Server CLI without accepting the terms."
.to_string(),
)))
}
Err(e) => return Err(AnyError::from(MissingLegalConsent(e.to_string()))),
}
}
if save {
license.save(load)?;
}
Ok(())
}

View file

@ -0,0 +1,218 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use rand::prelude::*;
// Adjectives in LEFT from Moby :
static LEFT: &[&str] = &[
"admiring",
"adoring",
"affectionate",
"agitated",
"amazing",
"angry",
"awesome",
"beautiful",
"blissful",
"bold",
"boring",
"brave",
"busy",
"charming",
"clever",
"cool",
"compassionate",
"competent",
"condescending",
"confident",
"cranky",
"crazy",
"dazzling",
"determined",
"distracted",
"dreamy",
"eager",
"ecstatic",
"elastic",
"elated",
"elegant",
"eloquent",
"epic",
"exciting",
"fervent",
"festive",
"flamboyant",
"focused",
"friendly",
"frosty",
"funny",
"gallant",
"gifted",
"goofy",
"gracious",
"great",
"happy",
"hardcore",
"heuristic",
"hopeful",
"hungry",
"infallible",
"inspiring",
"interesting",
"intelligent",
"jolly",
"jovial",
"keen",
"kind",
"laughing",
"loving",
"lucid",
"magical",
"mystifying",
"modest",
"musing",
"naughty",
"nervous",
"nice",
"nifty",
"nostalgic",
"objective",
"optimistic",
"peaceful",
"pedantic",
"pensive",
"practical",
"priceless",
"quirky",
"quizzical",
"recursing",
"relaxed",
"reverent",
"romantic",
"sad",
"serene",
"sharp",
"silly",
"sleepy",
"stoic",
"strange",
"stupefied",
"suspicious",
"sweet",
"tender",
"thirsty",
"trusting",
"unruffled",
"upbeat",
"vibrant",
"vigilant",
"vigorous",
"wizardly",
"wonderful",
"xenodochial",
"youthful",
"zealous",
"zen",
];
static RIGHT: &[&str] = &[
"albatross",
"antbird",
"antpitta",
"antshrike",
"antwren",
"babbler",
"barbet",
"blackbird",
"brushfinch",
"bulbul",
"bunting",
"cisticola",
"cormorant",
"crow",
"cuckoo",
"dove",
"drongo",
"duck",
"eagle",
"falcon",
"fantail",
"finch",
"flowerpecker",
"flycatcher",
"goose",
"goshawk",
"greenbul",
"grosbeak",
"gull",
"hawk",
"heron",
"honeyeater",
"hornbill",
"hummingbird",
"ibis",
"jay",
"kestrel",
"kingfisher",
"kite",
"lark",
"lorikeet",
"magpie",
"mockingbird",
"monarch",
"nightjar",
"oriole",
"owl",
"parakeet",
"parrot",
"partridge",
"penguin",
"petrel",
"pheasant",
"piculet",
"pigeon",
"pitta",
"prinia",
"puffin",
"quail",
"robin",
"sandpiper",
"seedeater",
"shearwater",
"sparrow",
"spinetail",
"starling",
"sunbird",
"swallow",
"swift",
"swiftlet",
"tanager",
"tapaculo",
"tern",
"thornbill",
"tinamou",
"trogon",
"tyrannulet",
"vireo",
"warbler",
"waxbill",
"weaver",
"whistler",
"woodpecker",
"wren",
];
/// Generates a random avian name, with the optional extra_random_length added
/// to reduce chance of in-flight collisions.
pub fn generate_name(max_length: usize) -> String {
let mut rng = rand::thread_rng();
loop {
let left = LEFT[rng.gen_range(0..LEFT.len())];
let right = RIGHT[rng.gen_range(0..RIGHT.len())];
let s = format!("{}-{}", left, right);
if s.len() < max_length {
return s;
}
}
}

View file

@ -0,0 +1,216 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{
fs::{read_dir, read_to_string, remove_dir_all, write},
path::PathBuf,
};
use serde::{Deserialize, Serialize};
use crate::{
log, options,
state::{LauncherPaths, PersistedState},
util::{
errors::{wrap, AnyError, WrappedError},
machine,
},
};
const INSIDERS_INSTALL_FOLDER: &str = "server-insiders";
const STABLE_INSTALL_FOLDER: &str = "server-stable";
const EXPLORATION_INSTALL_FOLDER: &str = "server-exploration";
const PIDFILE_SUFFIX: &str = ".pid";
const LOGFILE_SUFFIX: &str = ".log";
pub struct ServerPaths {
// Directory into which the server is downloaded
pub server_dir: PathBuf,
// Executable path, within the server_id
pub executable: PathBuf,
// File where logs for the server should be written.
pub logfile: PathBuf,
// File where the process ID for the server should be written.
pub pidfile: PathBuf,
}
impl ServerPaths {
// Queries the system to determine the process ID of the running server.
// Returns the process ID, if the server is running.
pub fn get_running_pid(&self) -> Option<u32> {
if let Some(pid) = self.read_pid() {
return match machine::process_at_path_exists(pid, &self.executable) {
true => Some(pid),
false => None,
};
}
if let Some(pid) = machine::find_running_process(&self.executable) {
// attempt to backfill process ID:
self.write_pid(pid).ok();
return Some(pid);
}
None
}
/// Delete the server directory
pub fn delete(&self) -> Result<(), WrappedError> {
remove_dir_all(&self.server_dir).map_err(|e| {
wrap(
e,
format!("error deleting server dir {}", self.server_dir.display()),
)
})
}
// VS Code Server pid
pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> {
write(&self.pidfile, &format!("{}", pid)).map_err(|e| {
wrap(
e,
format!("error writing process id into {}", self.pidfile.display()),
)
})
}
fn read_pid(&self) -> Option<u32> {
read_to_string(&self.pidfile)
.ok()
.and_then(|s| s.parse::<u32>().ok())
}
}
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct InstalledServer {
pub quality: options::Quality,
pub commit: String,
pub headless: bool,
}
impl InstalledServer {
/// Gets path information about where a specific server should be stored.
pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths {
let base_folder = self.get_install_folder(p);
let server_dir = base_folder.join("bin").join(&self.commit);
ServerPaths {
executable: server_dir
.join("bin")
.join(self.quality.server_entrypoint()),
server_dir,
logfile: base_folder.join(format!(".{}{}", self.commit, LOGFILE_SUFFIX)),
pidfile: base_folder.join(format!(".{}{}", self.commit, PIDFILE_SUFFIX)),
}
}
fn get_install_folder(&self, p: &LauncherPaths) -> PathBuf {
let name = match self.quality {
options::Quality::Insiders => INSIDERS_INSTALL_FOLDER,
options::Quality::Exploration => EXPLORATION_INSTALL_FOLDER,
options::Quality::Stable => STABLE_INSTALL_FOLDER,
};
p.root().join(if !self.headless {
format!("{}-web", name)
} else {
name.to_string()
})
}
}
pub struct LastUsedServers<'a> {
state: PersistedState<Vec<InstalledServer>>,
paths: &'a LauncherPaths,
}
impl<'a> LastUsedServers<'a> {
pub fn new(paths: &'a LauncherPaths) -> LastUsedServers {
LastUsedServers {
state: PersistedState::new(paths.root().join("last-used-servers.json")),
paths,
}
}
/// Adds a server as having been used most recently. Returns the number of retained server.
pub fn add(&self, server: InstalledServer) -> Result<usize, WrappedError> {
self.state.update_with(server, |server, l| {
if let Some(index) = l.iter().position(|s| s == &server) {
l.remove(index);
}
l.insert(0, server);
l.len()
})
}
/// Trims so that at most `max_servers` are saved on disk.
pub fn trim(&self, log: &log::Logger, max_servers: usize) -> Result<(), WrappedError> {
let mut servers = self.state.load();
while servers.len() > max_servers {
let server = servers.pop().unwrap();
debug!(
log,
"Removing old server {}/{}",
server.quality.get_machine_name(),
server.commit
);
let server_paths = server.server_paths(self.paths);
server_paths.delete()?;
}
self.state.save(servers)?;
Ok(())
}
}
/// Prunes servers not currently running, and returns the deleted servers.
pub fn prune_stopped_servers(launcher_paths: &LauncherPaths) -> Result<Vec<ServerPaths>, AnyError> {
get_all_servers(launcher_paths)
.into_iter()
.map(|s| s.server_paths(launcher_paths))
.filter(|s| s.get_running_pid().is_none())
.map(|s| s.delete().map(|_| s))
.collect::<Result<_, _>>()
.map_err(AnyError::from)
}
// Gets a list of all servers which look like they might be running.
pub fn get_all_servers(lp: &LauncherPaths) -> Vec<InstalledServer> {
let mut servers: Vec<InstalledServer> = vec![];
let mut server = InstalledServer {
commit: "".to_owned(),
headless: false,
quality: options::Quality::Stable,
};
add_server_paths_in_folder(lp, &server, &mut servers);
server.headless = true;
add_server_paths_in_folder(lp, &server, &mut servers);
server.headless = false;
server.quality = options::Quality::Insiders;
add_server_paths_in_folder(lp, &server, &mut servers);
server.headless = true;
add_server_paths_in_folder(lp, &server, &mut servers);
servers
}
fn add_server_paths_in_folder(
lp: &LauncherPaths,
server: &InstalledServer,
servers: &mut Vec<InstalledServer>,
) {
let dir = server.get_install_folder(lp).join("bin");
if let Ok(children) = read_dir(dir) {
for bin in children.flatten() {
servers.push(InstalledServer {
quality: server.quality,
headless: server.headless,
commit: bin.file_name().to_string_lossy().into(),
});
}
}
}

View file

@ -0,0 +1,130 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::collections::HashSet;
use tokio::sync::{mpsc, oneshot};
use crate::{
constants::CONTROL_PORT,
util::errors::{AnyError, CannotForwardControlPort, ServerHasClosed},
};
use super::dev_tunnels::ActiveTunnel;
pub enum PortForwardingRec {
Forward(u16, oneshot::Sender<Result<String, AnyError>>),
Unforward(u16, oneshot::Sender<Result<(), AnyError>>),
}
/// Provides a port forwarding service for connected clients. Clients can make
/// requests on it, which are (and *must be*) processed by calling the `.process()`
/// method on the forwarder.
pub struct PortForwardingProcessor {
tx: mpsc::Sender<PortForwardingRec>,
rx: mpsc::Receiver<PortForwardingRec>,
forwarded: HashSet<u16>,
}
impl PortForwardingProcessor {
pub fn new() -> Self {
let (tx, rx) = mpsc::channel(8);
Self {
tx,
rx,
forwarded: HashSet::new(),
}
}
/// Gets a handle that can be passed off to consumers of port forwarding.
pub fn handle(&self) -> PortForwarding {
PortForwarding {
tx: self.tx.clone(),
}
}
/// Receives port forwarding requests. Consumers MUST call `process()`
/// with the received requests.
pub async fn recv(&mut self) -> Option<PortForwardingRec> {
self.rx.recv().await
}
/// Processes the incoming forwarding request.
pub async fn process(&mut self, req: PortForwardingRec, tunnel: &mut ActiveTunnel) {
match req {
PortForwardingRec::Forward(port, tx) => {
tx.send(self.process_forward(port, tunnel).await).ok();
}
PortForwardingRec::Unforward(port, tx) => {
tx.send(self.process_unforward(port, tunnel).await).ok();
}
}
}
async fn process_unforward(
&mut self,
port: u16,
tunnel: &mut ActiveTunnel,
) -> Result<(), AnyError> {
if port == CONTROL_PORT {
return Err(CannotForwardControlPort().into());
}
tunnel.remove_port(port).await?;
self.forwarded.remove(&port);
Ok(())
}
async fn process_forward(
&mut self,
port: u16,
tunnel: &mut ActiveTunnel,
) -> Result<String, AnyError> {
if port == CONTROL_PORT {
return Err(CannotForwardControlPort().into());
}
if !self.forwarded.contains(&port) {
tunnel.add_port_tcp(port).await?;
self.forwarded.insert(port);
}
tunnel.get_port_uri(port).await
}
}
pub struct PortForwarding {
tx: mpsc::Sender<PortForwardingRec>,
}
impl PortForwarding {
pub async fn forward(&self, port: u16) -> Result<String, AnyError> {
let (tx, rx) = oneshot::channel();
let req = PortForwardingRec::Forward(port, tx);
if self.tx.send(req).await.is_err() {
return Err(ServerHasClosed().into());
}
match rx.await {
Ok(r) => r,
Err(_) => Err(ServerHasClosed().into()),
}
}
pub async fn unforward(&self, port: u16) -> Result<(), AnyError> {
let (tx, rx) = oneshot::channel();
let req = PortForwardingRec::Unforward(port, tx);
if self.tx.send(req).await.is_err() {
return Err(ServerHasClosed().into());
}
match rx.await {
Ok(r) => r,
Err(_) => Err(ServerHasClosed().into()),
}
}
}

View file

@ -0,0 +1,151 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::collections::HashMap;
use crate::options::Quality;
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Debug)]
#[serde(tag = "method", content = "params")]
#[allow(non_camel_case_types)]
pub enum ServerRequestMethod {
serve(ServeParams),
prune,
ping(EmptyResult),
forward(ForwardParams),
unforward(UnforwardParams),
gethostname(EmptyResult),
update(UpdateParams),
servermsg(ServerMessageParams),
callserverhttp(CallServerHttpParams),
}
#[derive(Serialize, Debug)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
#[allow(non_camel_case_types)]
pub enum ClientRequestMethod<'a> {
servermsg(RefServerMessageParams<'a>),
serverlog(ServerLog<'a>),
version(VersionParams),
}
#[derive(Deserialize, Debug)]
pub struct ForwardParams {
pub port: u16,
}
#[derive(Deserialize, Debug)]
pub struct UnforwardParams {
pub port: u16,
}
#[derive(Serialize)]
pub struct ForwardResult {
pub uri: String,
}
#[derive(Deserialize, Debug)]
pub struct ServeParams {
pub socket_id: u16,
pub commit_id: Option<String>,
pub quality: Quality,
pub extensions: Vec<String>,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct EmptyResult {}
#[derive(Serialize, Deserialize, Debug)]
pub struct UpdateParams {
pub do_update: bool,
}
#[derive(Deserialize, Debug)]
pub struct ServerMessageParams {
pub i: u16,
#[serde(with = "serde_bytes")]
pub body: Vec<u8>,
}
#[derive(Serialize, Debug)]
pub struct RefServerMessageParams<'a> {
pub i: u16,
#[serde(with = "serde_bytes")]
pub body: &'a [u8],
}
#[derive(Serialize)]
pub struct UpdateResult {
pub up_to_date: bool,
pub did_update: bool,
}
#[derive(Deserialize, Debug)]
pub struct ToServerRequest {
pub id: Option<u8>,
#[serde(flatten)]
pub params: ServerRequestMethod,
}
#[derive(Serialize, Debug)]
pub struct ToClientRequest<'a> {
pub id: Option<u8>,
#[serde(flatten)]
pub params: ClientRequestMethod<'a>,
}
#[derive(Serialize, Deserialize)]
pub struct SuccessResponse<T>
where
T: Serialize,
{
pub id: u8,
pub result: T,
}
#[derive(Serialize, Deserialize)]
pub struct ErrorResponse {
pub id: u8,
pub error: ResponseError,
}
#[derive(Serialize, Deserialize)]
pub struct ResponseError {
pub code: i32,
pub message: String,
}
#[derive(Debug, Default, Serialize)]
pub struct ServerLog<'a> {
pub line: &'a str,
pub level: u8,
}
#[derive(Serialize)]
pub struct GetHostnameResponse {
pub value: String,
}
#[derive(Deserialize, Debug)]
pub struct CallServerHttpParams {
pub path: String,
pub method: String,
pub headers: HashMap<String, String>,
pub body: Option<Vec<u8>>,
}
#[derive(Serialize)]
pub struct CallServerHttpResult {
pub status: u16,
#[serde(with = "serde_bytes")]
pub body: Vec<u8>,
pub headers: HashMap<String, String>,
}
#[derive(Serialize, Debug)]
pub struct VersionParams {
pub version: &'static str,
pub protocol_version: u32,
}

View file

@ -0,0 +1,80 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::path::Path;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{unix::OwnedWriteHalf, UnixStream},
sync::mpsc::Sender,
};
use crate::util::errors::{wrap, AnyError};
pub struct ServerBridge {
write: OwnedWriteHalf,
}
pub trait FromServerMessage {
fn from_server_message(index: u16, message: &[u8]) -> Self;
fn from_closed_server_bridge(i: u16) -> Self;
}
pub async fn get_socket_rw_stream(path: &Path) -> Result<UnixStream, AnyError> {
let s = UnixStream::connect(path).await.map_err(|e| {
wrap(
e,
format!(
"error connecting to vscode server socket in {}",
path.display()
),
)
})?;
Ok(s)
}
const BUFFER_SIZE: usize = 65536;
impl ServerBridge {
pub async fn new<T>(path: &Path, index: u16, target: &Sender<T>) -> Result<Self, AnyError>
where
T: 'static + FromServerMessage + Send,
{
let stream = get_socket_rw_stream(path).await?;
let (mut read, write) = stream.into_split();
let tx = target.clone();
tokio::spawn(async move {
let mut read_buf = vec![0; BUFFER_SIZE];
loop {
match read.read(&mut read_buf).await {
Err(_) => return,
Ok(0) => {
let _ = tx.send(T::from_closed_server_bridge(index)).await;
return; // EOF
}
Ok(s) => {
let send = tx.send(T::from_server_message(index, &read_buf[..s])).await;
if send.is_err() {
return;
}
}
}
}
});
Ok(ServerBridge { write })
}
pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
self.write.write_all(&b).await?;
Ok(())
}
pub async fn close(mut self) -> std::io::Result<()> {
self.write.shutdown().await?;
Ok(())
}
}

View file

@ -0,0 +1,133 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::{path::Path, time::Duration};
use tokio::{
io::{self, Interest},
net::windows::named_pipe::{ClientOptions, NamedPipeClient},
sync::mpsc,
time::sleep,
};
use crate::util::errors::{wrap, AnyError};
pub struct ServerBridge {
write_tx: mpsc::Sender<Vec<u8>>,
}
pub trait FromServerMessage {
fn from_server_message(index: u16, message: &[u8]) -> Self;
fn from_closed_server_bridge(i: u16) -> Self;
}
const BUFFER_SIZE: usize = 65536;
pub async fn get_socket_rw_stream(path: &Path) -> Result<NamedPipeClient, AnyError> {
// Tokio says we can need to try in a loop. Do so.
// https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
let client = loop {
match ClientOptions::new().open(path) {
Ok(client) => break client,
// ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499-
Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await,
Err(e) => {
return Err(AnyError::WrappedError(wrap(
e,
format!(
"error connecting to vscode server socket in {}",
path.display()
),
)))
}
}
};
Ok(client)
}
impl ServerBridge {
pub async fn new<T>(path: &Path, index: u16, target: &mpsc::Sender<T>) -> Result<Self, AnyError>
where
T: 'static + FromServerMessage + Send,
{
let client = get_socket_rw_stream(path).await?;
let (write_tx, mut write_rx) = mpsc::channel(4);
let read_tx = target.clone();
tokio::spawn(async move {
let mut read_buf = vec![0; BUFFER_SIZE];
let mut pending_recv: Option<Vec<u8>> = None;
// See https://docs.rs/tokio/1.17.0/tokio/net/windows/named_pipe/struct.NamedPipeClient.html#method.ready
// With additional complications. If there's nothing queued to write, we wait for the
// pipe to be readable, or for something to come in. If there is something to
// write, wait until the pipe is either readable or writable.
loop {
let ready_result = if pending_recv.is_none() {
tokio::select! {
msg = write_rx.recv() => match msg {
Some(msg) => {
pending_recv = Some(msg);
client.ready(Interest::READABLE | Interest::WRITABLE).await
},
None => return
},
r = client.ready(Interest::READABLE) => r,
}
} else {
client.ready(Interest::READABLE | Interest::WRITABLE).await
};
let ready = match ready_result {
Ok(r) => r,
Err(_) => return,
};
if ready.is_readable() {
match client.try_read(&mut read_buf) {
Ok(0) => return, // EOF
Ok(s) => {
let send = read_tx
.send(T::from_server_message(index, &read_buf[..s]))
.await;
if send.is_err() {
return;
}
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
continue;
}
Err(_) => return,
}
}
if let Some(msg) = &pending_recv {
if ready.is_writable() {
match client.try_write(msg) {
Ok(n) if n == msg.len() => pending_recv = None,
Ok(n) => pending_recv = Some(msg[n..].to_vec()),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
continue;
}
Err(_) => return,
}
}
}
}
});
Ok(ServerBridge { write_tx })
}
pub async fn write(&self, b: Vec<u8>) -> std::io::Result<()> {
self.write_tx.send(b).await.ok();
Ok(())
}
pub async fn close(self) -> std::io::Result<()> {
drop(self.write_tx);
Ok(())
}
}

View file

@ -0,0 +1,81 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::path::PathBuf;
use async_trait::async_trait;
use tokio::sync::oneshot;
use crate::log;
use crate::state::LauncherPaths;
use crate::util::errors::AnyError;
pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log";
#[async_trait]
pub trait ServiceContainer: Send {
async fn run_service(
&mut self,
log: log::Logger,
launcher_paths: LauncherPaths,
shutdown_rx: oneshot::Receiver<()>,
) -> Result<(), AnyError>;
}
pub trait ServiceManager {
/// Registers the current executable as a service to run with the given set
/// of arguments.
fn register(&self, exe: PathBuf, args: &[&str]) -> Result<(), AnyError>;
/// Runs the service using the given handle. The executable *must not* take
/// any action which may fail prior to calling this to ensure service
/// states may update.
fn run(
&self,
launcher_paths: LauncherPaths,
handle: impl 'static + ServiceContainer,
) -> Result<(), AnyError>;
/// Unregisters the current executable as a service.
fn unregister(&self) -> Result<(), AnyError>;
}
#[cfg(target_os = "windows")]
pub type ServiceManagerImpl = super::service_windows::WindowsService;
#[cfg(not(target_os = "windows"))]
pub type ServiceManagerImpl = UnimplementedServiceManager;
#[allow(unreachable_code)]
pub fn create_service_manager(log: log::Logger) -> ServiceManagerImpl {
ServiceManagerImpl::new(log)
}
pub struct UnimplementedServiceManager();
#[allow(dead_code)]
impl UnimplementedServiceManager {
fn new(_log: log::Logger) -> Self {
Self()
}
}
impl ServiceManager for UnimplementedServiceManager {
fn register(&self, _exe: PathBuf, _args: &[&str]) -> Result<(), AnyError> {
unimplemented!("Service management is not supported on this platform");
}
fn run(
&self,
_launcher_paths: LauncherPaths,
_handle: impl 'static + ServiceContainer,
) -> Result<(), AnyError> {
unimplemented!("Service management is not supported on this platform");
}
fn unregister(&self) -> Result<(), AnyError> {
unimplemented!("Service management is not supported on this platform");
}
}

View file

@ -0,0 +1,278 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use dialoguer::{theme::ColorfulTheme, Input, Password};
use lazy_static::lazy_static;
use std::{ffi::OsString, sync::Mutex, thread, time::Duration};
use tokio::sync::oneshot;
use windows_service::{
define_windows_service,
service::{
ServiceAccess, ServiceControl, ServiceControlAccept, ServiceErrorControl, ServiceExitCode,
ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType,
},
service_control_handler::{self, ServiceControlHandlerResult},
service_dispatcher,
service_manager::{ServiceManager, ServiceManagerAccess},
};
use crate::util::errors::{wrap, AnyError, WindowsNeedsElevation};
use crate::{
log::{self, FileLogSink},
state::LauncherPaths,
};
use super::service::{
ServiceContainer, ServiceManager as CliServiceManager, SERVICE_LOG_FILE_NAME,
};
pub struct WindowsService {
log: log::Logger,
}
const SERVICE_NAME: &str = "code_tunnel";
const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS;
impl WindowsService {
pub fn new(log: log::Logger) -> Self {
Self { log }
}
}
impl CliServiceManager for WindowsService {
fn register(&self, exe: std::path::PathBuf, args: &[&str]) -> Result<(), AnyError> {
let service_manager = ServiceManager::local_computer(
None::<&str>,
ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE,
)
.map_err(|e| WindowsNeedsElevation(format!("error getting service manager: {}", e)))?;
let mut service_info = ServiceInfo {
name: OsString::from(SERVICE_NAME),
display_name: OsString::from("VS Code Tunnel"),
service_type: SERVICE_TYPE,
start_type: ServiceStartType::AutoStart,
error_control: ServiceErrorControl::Normal,
executable_path: exe,
launch_arguments: args.iter().map(OsString::from).collect(),
dependencies: vec![],
account_name: None,
account_password: None,
};
let existing_service = service_manager.open_service(
SERVICE_NAME,
ServiceAccess::QUERY_STATUS | ServiceAccess::START | ServiceAccess::CHANGE_CONFIG,
);
let service = if let Ok(service) = existing_service {
service
.change_config(&service_info)
.map_err(|e| wrap(e, "error updating existing service"))?;
service
} else {
loop {
let (username, password) = prompt_credentials()?;
service_info.account_name = Some(format!(".\\{}", username).into());
service_info.account_password = Some(password.into());
match service_manager.create_service(
&service_info,
ServiceAccess::CHANGE_CONFIG | ServiceAccess::START,
) {
Ok(service) => break service,
Err(windows_service::Error::Winapi(e)) if Some(1057) == e.raw_os_error() => {
error!(
self.log,
"Invalid username or password, please try again..."
);
}
Err(e) => return Err(wrap(e, "error registering service").into()),
}
}
};
service
.set_description("Service that runs `code tunnel` for access on vscode.dev")
.ok();
info!(self.log, "Successfully registered service...");
let status = service
.query_status()
.map(|s| s.current_state)
.unwrap_or(ServiceState::Stopped);
if status == ServiceState::Stopped {
service
.start::<&str>(&[])
.map_err(|e| wrap(e, "error starting service"))?;
}
info!(self.log, "Tunnel service successfully started");
Ok(())
}
#[allow(unused_must_use)] // triggers incorrectly on `define_windows_service!`
fn run(
&self,
launcher_paths: LauncherPaths,
handle: impl 'static + ServiceContainer,
) -> Result<(), AnyError> {
let log = match FileLogSink::new(
log::Level::Debug,
&launcher_paths.root().join(SERVICE_LOG_FILE_NAME),
) {
Ok(sink) => self.log.tee(sink),
Err(e) => {
warning!(self.log, "Failed to create service log file: {}", e);
self.log.clone()
}
};
// We put the handle into the global "impl" type and then take it out in
// my_service_main. This is needed just since we have to have that
// function at the root level, but need to pass in data later here...
SERVICE_IMPL.lock().unwrap().replace(ServiceImpl {
container: Box::new(handle),
launcher_paths,
log,
});
define_windows_service!(ffi_service_main, service_main);
service_dispatcher::start(SERVICE_NAME, ffi_service_main)
.map_err(|e| wrap(e, "error starting service dispatcher").into())
}
fn unregister(&self) -> Result<(), AnyError> {
let service_manager =
ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)
.map_err(|e| wrap(e, "error getting service manager"))?;
let service = service_manager.open_service(
SERVICE_NAME,
ServiceAccess::QUERY_STATUS | ServiceAccess::STOP | ServiceAccess::DELETE,
);
let service = match service {
Ok(service) => service,
// Service does not exist:
Err(windows_service::Error::Winapi(e)) if Some(1060) == e.raw_os_error() => {
return Ok(())
}
Err(e) => return Err(wrap(e, "error getting service handle").into()),
};
let service_status = service
.query_status()
.map_err(|e| wrap(e, "error getting service status"))?;
if service_status.current_state != ServiceState::Stopped {
service
.stop()
.map_err(|e| wrap(e, "error getting stopping service"))?;
while let Ok(ServiceState::Stopped) = service.query_status().map(|s| s.current_state) {
info!(self.log, "Polling for service to stop...");
thread::sleep(Duration::from_secs(1));
}
}
service
.delete()
.map_err(|e| wrap(e, "error deleting service"))?;
Ok(())
}
}
struct ServiceImpl {
container: Box<dyn ServiceContainer>,
launcher_paths: LauncherPaths,
log: log::Logger,
}
lazy_static! {
static ref SERVICE_IMPL: Mutex<Option<ServiceImpl>> = Mutex::new(None);
}
/// "main" function that the service calls in its own thread.
fn service_main(_arguments: Vec<OsString>) -> Result<(), AnyError> {
let mut service = SERVICE_IMPL.lock().unwrap().take().unwrap();
// Create a channel to be able to poll a stop event from the service worker loop.
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let mut shutdown_tx = Some(shutdown_tx);
// Define system service event handler that will be receiving service events.
let event_handler = move |control_event| -> ServiceControlHandlerResult {
match control_event {
ServiceControl::Interrogate => ServiceControlHandlerResult::NoError,
ServiceControl::Stop => {
shutdown_tx.take().and_then(|tx| tx.send(()).ok());
ServiceControlHandlerResult::NoError
}
_ => ServiceControlHandlerResult::NotImplemented,
}
};
let status_handle = service_control_handler::register(SERVICE_NAME, event_handler)
.map_err(|e| wrap(e, "error registering service event handler"))?;
// Tell the system that service is running
status_handle
.set_service_status(ServiceStatus {
service_type: SERVICE_TYPE,
current_state: ServiceState::Running,
controls_accepted: ServiceControlAccept::STOP,
exit_code: ServiceExitCode::Win32(0),
checkpoint: 0,
wait_hint: Duration::default(),
process_id: None,
})
.map_err(|e| wrap(e, "error marking service as running"))?;
let result = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(
service
.container
.run_service(service.log, service.launcher_paths, shutdown_rx),
);
status_handle
.set_service_status(ServiceStatus {
service_type: SERVICE_TYPE,
current_state: ServiceState::Stopped,
controls_accepted: ServiceControlAccept::empty(),
exit_code: ServiceExitCode::Win32(0),
checkpoint: 0,
wait_hint: Duration::default(),
process_id: None,
})
.map_err(|e| wrap(e, "error marking service as stopped"))?;
result
}
fn prompt_credentials() -> Result<(String, String), AnyError> {
println!("Running a Windows service under your user requires your username and password.");
println!("These are sent to the Windows Service Manager and are not stored by VS Code.");
let username: String = Input::with_theme(&ColorfulTheme::default())
.with_prompt("Windows username:")
.interact_text()
.map_err(|e| wrap(e, "Failed to read username"))?;
let password = Password::with_theme(&ColorfulTheme::default())
.with_prompt("Windows password:")
.interact()
.map_err(|e| wrap(e, "Failed to read password"))?;
Ok((username, password))
}

120
src/cli/src/update.rs Normal file
View file

@ -0,0 +1,120 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::constants::{LAUNCHER_ASSET_NAME, LAUNCHER_VERSION};
use crate::util::{errors, http, io::SilentCopyProgress};
use serde::Deserialize;
use std::{
fs::{rename, set_permissions},
path::Path,
};
pub struct Update {
client: reqwest::Client,
}
const LATEST_URL: &str = "https://aka.ms/vscode-server-launcher/update";
impl Default for Update {
fn default() -> Self {
Self::new()
}
}
impl Update {
// Creates a new Update instance without authentication
pub fn new() -> Update {
Update {
client: reqwest::Client::new(),
}
}
// Gets the asset to update to, or None if the current launcher is up to date.
pub async fn get_latest_release(&self) -> Result<LauncherRelease, errors::AnyError> {
let res = self
.client
.get(LATEST_URL)
.header(
"User-Agent",
format!(
"vscode-server-launcher/{}",
LAUNCHER_VERSION.unwrap_or("dev")
),
)
.send()
.await?;
if !res.status().is_success() {
return Err(errors::StatusError::from_res(res).await?.into());
}
Ok(res.json::<LauncherRelease>().await?)
}
pub async fn switch_to_release(
&self,
update: &LauncherRelease,
target_path: &Path,
) -> Result<(), errors::AnyError> {
let mut staging_path = target_path.to_owned();
staging_path.set_file_name(format!(
"{}.next",
target_path.file_name().unwrap().to_string_lossy()
));
let an = LAUNCHER_ASSET_NAME.unwrap();
let mut url = format!("{}/{}/{}", update.url, an, an);
if cfg!(target_os = "windows") {
url += ".exe";
}
let res = self.client.get(url).send().await?;
if !res.status().is_success() {
return Err(errors::StatusError::from_res(res).await?.into());
}
http::download_into_file(&staging_path, SilentCopyProgress(), res).await?;
copy_file_metadata(target_path, &staging_path)
.map_err(|e| errors::wrap(e, "failed to set file permissions"))?;
rename(&staging_path, &target_path)
.map_err(|e| errors::wrap(e, "failed to copy new launcher version"))?;
Ok(())
}
}
#[derive(Deserialize, Clone)]
pub struct LauncherRelease {
pub version: String,
pub url: String,
pub released_at: u64,
}
#[cfg(target_os = "windows")]
fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> {
let permissions = from.metadata()?.permissions();
set_permissions(&to, permissions)?;
Ok(())
}
#[cfg(not(target_os = "windows"))]
fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> {
use std::os::unix::ffi::OsStrExt;
use std::os::unix::fs::MetadataExt;
let metadata = from.metadata()?;
set_permissions(&to, metadata.permissions())?;
// based on coreutils' chown https://github.com/uutils/coreutils/blob/72b4629916abe0852ad27286f4e307fbca546b6e/src/chown/chown.rs#L266-L281
let s = std::ffi::CString::new(to.as_os_str().as_bytes()).unwrap();
let ret = unsafe { libc::chown(s.as_ptr(), metadata.uid(), metadata.gid()) };
if ret != 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
}

View file

@ -0,0 +1,264 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::path::Path;
use serde::Deserialize;
use crate::{
debug, log, options, spanf,
util::{
errors::{AnyError, StatusError, UnsupportedPlatformError, WrappedError},
io::ReportCopyProgress,
},
};
/// Implementation of the VS Code Update service for use in the CLI.
pub struct UpdateService {
client: reqwest::Client,
log: log::Logger,
}
/// Describes a specific release, can be created manually or returned from the update service.
pub struct Release {
pub platform: Platform,
pub target: TargetKind,
pub quality: options::Quality,
pub commit: String,
}
#[derive(Deserialize)]
struct UpdateServerVersion {
pub version: String,
}
fn quality_download_segment(quality: options::Quality) -> &'static str {
match quality {
options::Quality::Stable => "stable",
options::Quality::Insiders => "insider",
options::Quality::Exploration => "exploration",
}
}
impl UpdateService {
pub fn new(log: log::Logger, client: reqwest::Client) -> Self {
UpdateService { client, log }
}
pub async fn get_release_by_semver_version(
&self,
platform: Platform,
target: TargetKind,
quality: options::Quality,
version: &str,
) -> Result<Release, AnyError> {
let download_segment = target
.download_segment(platform)
.ok_or(UnsupportedPlatformError())?;
let download_url = format!(
"https://update.code.visualstudio.com/api/versions/{}/{}/{}",
version,
download_segment,
quality_download_segment(quality),
);
let response = spanf!(
self.log,
self.log.span("server.version.resolve"),
self.client.get(download_url).send()
)?;
if !response.status().is_success() {
return Err(StatusError::from_res(response).await?.into());
}
let res = response.json::<UpdateServerVersion>().await?;
debug!(self.log, "Resolved version {} to {}", version, res.version);
Ok(Release {
target,
platform,
quality,
commit: res.version,
})
}
/// Gets the latest commit for the target of the given quality.
pub async fn get_latest_commit(
&self,
platform: Platform,
target: TargetKind,
quality: options::Quality,
) -> Result<Release, AnyError> {
let download_segment = target
.download_segment(platform)
.ok_or(UnsupportedPlatformError())?;
let download_url = format!(
"https://update.code.visualstudio.com/api/latest/{}/{}",
download_segment,
quality_download_segment(quality),
);
let response = spanf!(
self.log,
self.log.span("server.version.resolve"),
self.client.get(download_url).send()
)?;
if !response.status().is_success() {
return Err(StatusError::from_res(response).await?.into());
}
let res = response.json::<UpdateServerVersion>().await?;
debug!(self.log, "Resolved quality {} to {}", quality, res.version);
Ok(Release {
target,
platform,
quality,
commit: res.version,
})
}
/// Gets the download stream for the release.
pub async fn get_download_stream(
&self,
release: &Release,
) -> Result<reqwest::Response, AnyError> {
let download_segment = release
.target
.download_segment(release.platform)
.ok_or(UnsupportedPlatformError())?;
let download_url = format!(
"https://update.code.visualstudio.com/commit:{}/{}/{}",
release.commit,
download_segment,
quality_download_segment(release.quality),
);
let response = reqwest::get(&download_url).await?;
if !response.status().is_success() {
return Err(StatusError::from_res(response).await?.into());
}
Ok(response)
}
}
pub fn unzip_downloaded_release<T>(
compressed_file: &Path,
target_dir: &Path,
reporter: T,
) -> Result<(), WrappedError>
where
T: ReportCopyProgress,
{
#[cfg(any(target_os = "windows", target_os = "macos"))]
{
use crate::util::zipper;
zipper::unzip_file(compressed_file, target_dir, reporter)
}
#[cfg(target_os = "linux")]
{
use crate::util::tar;
tar::decompress_tarball(compressed_file, target_dir, reporter)
}
}
#[derive(Eq, PartialEq, Copy, Clone)]
pub enum TargetKind {
Server,
Archive,
Web,
}
impl TargetKind {
fn download_segment(&self, platform: Platform) -> Option<String> {
match *self {
TargetKind::Server => Some(platform.headless()),
TargetKind::Archive => platform.archive(),
TargetKind::Web => Some(platform.web()),
}
}
}
#[derive(Debug, Copy, Clone)]
pub enum Platform {
LinuxAlpineX64,
LinuxAlpineARM64,
LinuxX64,
LinuxARM64,
LinuxARM32,
DarwinX64,
DarwinARM64,
WindowsX64,
WindowsX86,
}
impl Platform {
pub fn archive(&self) -> Option<String> {
match self {
Platform::LinuxX64 => Some("linux-x64".to_owned()),
Platform::LinuxARM64 => Some("linux-arm64".to_owned()),
Platform::LinuxARM32 => Some("linux-armhf".to_owned()),
Platform::DarwinX64 => Some("darwin".to_owned()),
Platform::DarwinARM64 => Some("darwin-arm64".to_owned()),
Platform::WindowsX64 => Some("win32-x64-archive".to_owned()),
Platform::WindowsX86 => Some("win32-archive".to_owned()),
_ => None,
}
}
pub fn headless(&self) -> String {
match self {
Platform::LinuxAlpineARM64 => "server-alpine-arm64",
Platform::LinuxAlpineX64 => "server-linux-alpine",
Platform::LinuxX64 => "server-linux-x64",
Platform::LinuxARM64 => "server-linux-arm64",
Platform::LinuxARM32 => "server-linux-armhf",
Platform::DarwinX64 => "server-darwin",
Platform::DarwinARM64 => "server-darwin-arm64",
Platform::WindowsX64 => "server-win32-x64",
Platform::WindowsX86 => "server-win32",
}
.to_owned()
}
pub fn web(&self) -> String {
format!("{}-web", self.headless())
}
pub fn env_default() -> Option<Platform> {
if cfg!(all(
target_os = "linux",
target_arch = "x86_64",
target_env = "musl"
)) {
Some(Platform::LinuxAlpineX64)
} else if cfg!(all(
target_os = "linux",
target_arch = "aarch64",
target_env = "musl"
)) {
Some(Platform::LinuxAlpineARM64)
} else if cfg!(all(target_os = "linux", target_arch = "x86_64")) {
Some(Platform::LinuxX64)
} else if cfg!(all(target_os = "linux", target_arch = "armhf")) {
Some(Platform::LinuxARM32)
} else if cfg!(all(target_os = "linux", target_arch = "aarch64")) {
Some(Platform::LinuxARM64)
} else if cfg!(all(target_os = "macos", target_arch = "x86_64")) {
Some(Platform::DarwinX64)
} else if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
Some(Platform::DarwinARM64)
} else if cfg!(all(target_os = "windows", target_arch = "x86_64")) {
Some(Platform::WindowsX64)
} else if cfg!(all(target_os = "windows", target_arch = "x86")) {
Some(Platform::WindowsX86)
} else {
None
}
}
}

19
src/cli/src/util.rs Normal file
View file

@ -0,0 +1,19 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
pub mod command;
pub mod errors;
pub mod http;
pub mod input;
pub mod io;
pub mod machine;
pub mod prereqs;
pub mod sync;
#[cfg(target_os = "linux")]
pub mod tar;
#[cfg(any(target_os = "windows", target_os = "macos"))]
pub mod zipper;

View file

@ -0,0 +1,77 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use super::errors::{wrap, WrappedError};
use std::{ffi::OsStr, process::Stdio};
use tokio::process::Command;
pub async fn capture_command<A, I, S>(
command_str: A,
args: I,
) -> Result<std::process::Output, WrappedError>
where
A: AsRef<OsStr>,
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
Command::new(&command_str)
.args(args)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.output()
.await
.map_err(|e| {
wrap(
e,
format!(
"failed to execute command '{}'",
(&command_str).as_ref().to_string_lossy()
),
)
})
}
/// Kills and processes and all of its children.
#[cfg(target_os = "windows")]
pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?;
Ok(())
}
/// Kills and processes and all of its children.
#[cfg(not(target_os = "windows"))]
pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> {
use futures::future::join_all;
use tokio::io::{AsyncBufReadExt, BufReader};
async fn kill_single_pid(process_id_str: String) {
capture_command("kill", &[&process_id_str]).await.ok();
}
// Rusty version of https://github.com/microsoft/vscode-js-debug/blob/main/src/targets/node/terminateProcess.sh
let parent_id = process_id.to_string();
let mut prgrep_cmd = Command::new("pgrep")
.arg("-P")
.arg(&parent_id)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.spawn()
.map_err(|e| wrap(e, "error enumerating process tree"))?;
let mut kill_futures = vec![tokio::spawn(
async move { kill_single_pid(parent_id).await },
)];
if let Some(stdout) = prgrep_cmd.stdout.take() {
let mut reader = BufReader::new(stdout).lines();
while let Some(line) = reader.next_line().await.unwrap_or(None) {
kill_futures.push(tokio::spawn(async move { kill_single_pid(line).await }))
}
}
join_all(kill_futures).await;
prgrep_cmd.kill().await.ok();
Ok(())
}

418
src/cli/src/util/errors.rs Normal file
View file

@ -0,0 +1,418 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::fmt::Display;
use crate::constants::CONTROL_PORT;
// Wraps another error with additional info.
#[derive(Debug, Clone)]
pub struct WrappedError {
message: String,
original: String,
}
impl std::fmt::Display for WrappedError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}: {}", self.message, self.original)
}
}
impl std::error::Error for WrappedError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
impl WrappedError {
// fn new(original: Box<dyn std::error::Error>, message: String) -> WrappedError {
// WrappedError { message, original }
// }
}
impl From<reqwest::Error> for WrappedError {
fn from(e: reqwest::Error) -> WrappedError {
WrappedError {
message: format!(
"error requesting {}",
e.url().map_or("<unknown>", |u| u.as_str())
),
original: format!("{}", e),
}
}
}
pub fn wrap<T, S>(original: T, message: S) -> WrappedError
where
T: Display,
S: Into<String>,
{
WrappedError {
message: message.into(),
original: format!("{}", original),
}
}
// Error generated by an unsuccessful HTTP response
#[derive(Debug)]
pub struct StatusError {
url: String,
status_code: u16,
body: String,
}
impl std::fmt::Display for StatusError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"error requesting {}: {} {}",
self.url, self.status_code, self.body
)
}
}
impl StatusError {
pub async fn from_res(res: reqwest::Response) -> Result<StatusError, AnyError> {
let status_code = res.status().as_u16();
let url = res.url().to_string();
let body = res.text().await.map_err(|e| {
wrap(
e,
format!(
"failed to read response body on {} code from {}",
status_code, url
),
)
})?;
Ok(StatusError {
url,
status_code,
body,
})
}
}
// When the user has not consented to the licensing terms in using the Launcher
#[derive(Debug)]
pub struct MissingLegalConsent(pub String);
impl std::fmt::Display for MissingLegalConsent {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
// When the provided connection token doesn't match the one used to set up the original VS Code Server
// This is most likely due to a new user joining.
#[derive(Debug)]
pub struct MismatchConnectionToken(pub String);
impl std::fmt::Display for MismatchConnectionToken {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
// When the VS Code server has an unrecognized extension (rather than zip or gz)
#[derive(Debug)]
pub struct InvalidServerExtensionError(pub String);
impl std::fmt::Display for InvalidServerExtensionError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "invalid server extension '{}'", self.0)
}
}
// When the tunnel fails to open
#[derive(Debug, Clone)]
pub struct DevTunnelError(pub String);
impl std::fmt::Display for DevTunnelError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "could not open tunnel: {}", self.0)
}
}
impl std::error::Error for DevTunnelError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
// When the server was downloaded, but the entrypoint scripts don't exist.
#[derive(Debug)]
pub struct MissingEntrypointError();
impl std::fmt::Display for MissingEntrypointError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Missing entrypoints in server download. Most likely this is a corrupted download. Please retry")
}
}
#[derive(Debug)]
pub struct SetupError(pub String);
impl std::fmt::Display for SetupError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{}\r\n\r\nMore info at https://code.visualstudio.com/docs/remote/linux",
self.0
)
}
}
#[derive(Debug)]
pub struct NoHomeForLauncherError();
impl std::fmt::Display for NoHomeForLauncherError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"No $HOME variable was found in your environment. Either set it, or specify a `--data-dir` manually when invoking the launcher.",
)
}
}
#[derive(Debug)]
pub struct InvalidTunnelName(pub String);
impl std::fmt::Display for InvalidTunnelName {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
#[derive(Debug)]
pub struct TunnelCreationFailed(pub String, pub String);
impl std::fmt::Display for TunnelCreationFailed {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"Could not create tunnel with name: {}\nReason: {}",
&self.0, &self.1
)
}
}
#[derive(Debug)]
pub struct TunnelHostFailed(pub String);
impl std::fmt::Display for TunnelHostFailed {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
#[derive(Debug)]
pub struct ExtensionInstallFailed(pub String);
impl std::fmt::Display for ExtensionInstallFailed {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Extension install failed: {}", &self.0)
}
}
#[derive(Debug)]
pub struct MismatchedLaunchModeError();
impl std::fmt::Display for MismatchedLaunchModeError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "A server is already running, but it was not launched in the same listening mode (port vs. socket) as this request")
}
}
#[derive(Debug)]
pub struct NoAttachedServerError();
impl std::fmt::Display for NoAttachedServerError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "No server is running")
}
}
#[derive(Debug)]
pub struct ServerWriteError();
impl std::fmt::Display for ServerWriteError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Error writing to the server, it should be restarted")
}
}
#[derive(Debug)]
pub struct RefreshTokenNotAvailableError();
impl std::fmt::Display for RefreshTokenNotAvailableError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Refresh token not available, authentication is required")
}
}
#[derive(Debug)]
pub struct UnsupportedPlatformError();
impl std::fmt::Display for UnsupportedPlatformError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"This operation is not supported on your current platform"
)
}
}
#[derive(Debug)]
pub struct NoInstallInUserProvidedPath(pub String);
impl std::fmt::Display for NoInstallInUserProvidedPath {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"No VS Code installation could be found in {}. You can run `code --use-quality=stable` to switch to the latest stable version of VS Code.",
self.0
)
}
}
#[derive(Debug)]
pub struct InvalidRequestedVersion();
impl std::fmt::Display for InvalidRequestedVersion {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"The reqested version is invalid, expected one of 'stable', 'insiders', version number (x.y.z), or absolute path.",
)
}
}
#[derive(Debug)]
pub struct UserCancelledInstallation();
impl std::fmt::Display for UserCancelledInstallation {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Installation aborted.")
}
}
#[derive(Debug)]
pub struct CannotForwardControlPort();
impl std::fmt::Display for CannotForwardControlPort {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Cannot forward or unforward port {}.", CONTROL_PORT)
}
}
#[derive(Debug)]
pub struct ServerHasClosed();
impl std::fmt::Display for ServerHasClosed {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Request cancelled because the server has closed")
}
}
#[derive(Debug)]
pub struct ServiceAlreadyRegistered();
impl std::fmt::Display for ServiceAlreadyRegistered {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Already registered the service. Run `code tunnel service uninstall` to unregister it first")
}
}
#[derive(Debug)]
pub struct WindowsNeedsElevation(pub String);
impl std::fmt::Display for WindowsNeedsElevation {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(f, "{}", self.0)?;
writeln!(f)?;
writeln!(f, "You may need to run this command as an administrator:")?;
writeln!(f, " 1. Open the start menu and search for Powershell")?;
writeln!(f, " 2. Right click and 'Run as administrator'")?;
if let Ok(exe) = std::env::current_exe() {
writeln!(
f,
" 3. Run &'{}' '{}'",
exe.display(),
std::env::args().skip(1).collect::<Vec<_>>().join("' '")
)
} else {
writeln!(f, " 3. Run the same command again",)
}
}
}
// Makes an "AnyError" enum that contains any of the given errors, in the form
// `enum AnyError { FooError(FooError) }` (when given `makeAnyError!(FooError)`).
// Useful to easily deal with application error types without making tons of "From"
// clauses.
macro_rules! makeAnyError {
($($e:ident),*) => {
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
pub enum AnyError {
$($e($e),)*
}
impl std::fmt::Display for AnyError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
$(AnyError::$e(ref e) => e.fmt(f),)*
}
}
}
impl std::error::Error for AnyError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
$(impl From<$e> for AnyError {
fn from(e: $e) -> AnyError {
AnyError::$e(e)
}
})*
};
}
makeAnyError!(
MissingLegalConsent,
MismatchConnectionToken,
DevTunnelError,
StatusError,
WrappedError,
InvalidServerExtensionError,
MissingEntrypointError,
SetupError,
NoHomeForLauncherError,
TunnelCreationFailed,
TunnelHostFailed,
InvalidTunnelName,
ExtensionInstallFailed,
MismatchedLaunchModeError,
NoAttachedServerError,
ServerWriteError,
UnsupportedPlatformError,
RefreshTokenNotAvailableError,
NoInstallInUserProvidedPath,
UserCancelledInstallation,
InvalidRequestedVersion,
CannotForwardControlPort,
ServerHasClosed,
ServiceAlreadyRegistered,
WindowsNeedsElevation
);
impl From<reqwest::Error> for AnyError {
fn from(e: reqwest::Error) -> AnyError {
AnyError::WrappedError(WrappedError::from(e))
}
}

36
src/cli/src/util/http.rs Normal file
View file

@ -0,0 +1,36 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::util::errors::{self, WrappedError};
use futures::stream::TryStreamExt;
use tokio::fs;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use super::io::{copy_async_progress, ReportCopyProgress};
pub async fn download_into_file<T>(
filename: &std::path::Path,
progress: T,
res: reqwest::Response,
) -> Result<fs::File, WrappedError>
where
T: ReportCopyProgress,
{
let mut file = fs::File::create(filename)
.await
.map_err(|e| errors::wrap(e, "failed to create file"))?;
let content_length = res.content_length().unwrap_or(0);
let mut read = res
.bytes_stream()
.map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
.into_async_read()
.compat();
copy_async_progress(progress, &mut read, &mut file, content_length)
.await
.map_err(|e| errors::wrap(e, "failed to download file"))?;
Ok(file)
}

69
src/cli/src/util/input.rs Normal file
View file

@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::util::errors::wrap;
use dialoguer::{theme::ColorfulTheme, Confirm, Input, Select};
use indicatif::ProgressBar;
use std::fmt::Display;
use super::{errors::WrappedError, io::ReportCopyProgress};
/// Wrapper around indicatif::ProgressBar that implements ReportCopyProgress.
pub struct ProgressBarReporter {
bar: ProgressBar,
has_set_total: bool,
}
impl From<ProgressBar> for ProgressBarReporter {
fn from(bar: ProgressBar) -> Self {
ProgressBarReporter {
bar,
has_set_total: false,
}
}
}
impl ReportCopyProgress for ProgressBarReporter {
fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) {
if !self.has_set_total {
self.bar.set_length(total_bytes);
}
if bytes_so_far == total_bytes {
self.bar.finish_and_clear();
} else {
self.bar.set_position(bytes_so_far);
}
}
}
pub fn prompt_yn(text: &str) -> Result<bool, WrappedError> {
Confirm::with_theme(&ColorfulTheme::default())
.with_prompt(text)
.default(true)
.interact()
.map_err(|e| wrap(e, "Failed to read confirm input"))
}
pub fn prompt_options<T>(text: &str, options: &[T]) -> Result<T, WrappedError>
where
T: Display + Copy,
{
let chosen = Select::with_theme(&ColorfulTheme::default())
.with_prompt(text)
.items(options)
.default(0)
.interact()
.map_err(|e| wrap(e, "Failed to read select input"))?;
Ok(options[chosen])
}
pub fn prompt_placeholder(question: &str, placeholder: &str) -> Result<String, WrappedError> {
Input::with_theme(&ColorfulTheme::default())
.with_prompt(question)
.default(placeholder.to_string())
.interact_text()
.map_err(|e| wrap(e, "Failed to read confirm input"))
}

59
src/cli/src/util/io.rs Normal file
View file

@ -0,0 +1,59 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub trait ReportCopyProgress {
fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64);
}
/// Type that doesn't emit anything for download progress.
pub struct SilentCopyProgress();
impl ReportCopyProgress for SilentCopyProgress {
fn report_progress(&mut self, _bytes_so_far: u64, _total_bytes: u64) {}
}
/// Copies from the reader to the writer, reporting progress to the provided
/// reporter every so often.
pub async fn copy_async_progress<T, R, W>(
mut reporter: T,
reader: &mut R,
writer: &mut W,
total_bytes: u64,
) -> io::Result<u64>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
T: ReportCopyProgress,
{
let mut buf = vec![0; 8 * 1024];
let mut bytes_so_far = 0;
let mut bytes_last_reported = 0;
let report_granularity = std::cmp::min(total_bytes / 10, 2 * 1024 * 1024);
reporter.report_progress(0, total_bytes);
loop {
let read_buf = match reader.read(&mut buf).await {
Ok(0) => break,
Ok(n) => &buf[..n],
Err(e) => return Err(e),
};
writer.write_all(read_buf).await?;
bytes_so_far += read_buf.len() as u64;
if bytes_so_far - bytes_last_reported > report_granularity {
bytes_last_reported = bytes_so_far;
reporter.report_progress(bytes_so_far, total_bytes);
}
}
reporter.report_progress(bytes_so_far, total_bytes);
Ok(bytes_so_far)
}

View file

@ -0,0 +1,78 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::util::errors;
use std::path::Path;
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};
pub fn process_at_path_exists(pid: u32, name: &Path) -> bool {
// TODO https://docs.rs/sysinfo/latest/sysinfo/index.html#usage
let mut sys = System::new_all();
sys.refresh_processes();
let name_str = format!("{}", name.display());
match sys.process(Pid::from_u32(pid)) {
Some(process) => {
for cmd in process.cmd() {
if cmd.contains(&name_str) {
return true;
}
}
}
None => {
return false;
}
}
false
}
pub fn process_exists(pid: u32) -> bool {
let mut sys = System::new_all();
sys.refresh_processes();
sys.process(Pid::from_u32(pid)).is_some()
}
pub fn find_running_process(name: &Path) -> Option<u32> {
// TODO https://docs.rs/sysinfo/latest/sysinfo/index.html#usage
let mut sys = System::new_all();
sys.refresh_processes();
let name_str = format!("{}", name.display());
for (pid, process) in sys.processes() {
for cmd in process.cmd() {
if cmd.contains(&name_str) {
return Some(pid.as_u32());
}
}
}
None
}
#[cfg(not(target_family = "unix"))]
pub async fn set_executable_permission<P: AsRef<std::path::Path>>(
_file: P,
) -> Result<(), errors::WrappedError> {
Ok(())
}
#[cfg(target_family = "unix")]
pub async fn set_executable_permission<P: AsRef<std::path::Path>>(
file: P,
) -> Result<(), errors::WrappedError> {
use std::os::unix::prelude::PermissionsExt;
let mut permissions = tokio::fs::metadata(&file)
.await
.map_err(|e| errors::wrap(e, "failed to read executable file metadata"))?
.permissions();
permissions.set_mode(0o750);
tokio::fs::set_permissions(&file, permissions)
.await
.map_err(|e| errors::wrap(e, "failed to set executable permissions"))?;
Ok(())
}

301
src/cli/src/util/prereqs.rs Normal file
View file

@ -0,0 +1,301 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use std::cmp::Ordering;
use super::command::capture_command;
use crate::update_service::Platform;
use crate::util::errors::SetupError;
use lazy_static::lazy_static;
use regex::bytes::Regex as BinRegex;
use regex::Regex;
use tokio::fs;
use super::errors::AnyError;
lazy_static! {
static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap();
static ref LDD_VERSION_RE: BinRegex = BinRegex::new(r"^ldd.*(.+)\.(.+)\s").unwrap();
static ref LIBSTD_CXX_VERSION_RE: BinRegex =
BinRegex::new(r"GLIBCXX_([0-9]+)\.([0-9]+)(?:\.([0-9]+))?").unwrap();
static ref MIN_CXX_VERSION: SimpleSemver = SimpleSemver::new(3, 4, 18);
static ref MIN_LDD_VERSION: SimpleSemver = SimpleSemver::new(2, 17, 0);
}
pub struct PreReqChecker {}
impl Default for PreReqChecker {
fn default() -> Self {
Self::new()
}
}
impl PreReqChecker {
pub fn new() -> PreReqChecker {
PreReqChecker {}
}
#[cfg(not(target_os = "linux"))]
pub async fn verify(&self) -> Result<Platform, AnyError> {
Platform::env_default().ok_or_else(|| {
SetupError("VS Code it not supported on this platform".to_owned()).into()
})
}
#[cfg(target_os = "linux")]
pub async fn verify(&self) -> Result<Platform, AnyError> {
let (gnu_a, gnu_b, or_musl) = tokio::join!(
check_glibc_version(),
check_glibcxx_version(),
check_musl_interpreter()
);
if gnu_a.is_ok() && gnu_b.is_ok() {
return Ok(if cfg!(target_arch = "x86_64") {
Platform::LinuxX64
} else if cfg!(target_arch = "armhf") {
Platform::LinuxARM32
} else {
Platform::LinuxARM64
});
}
if or_musl.is_ok() {
return Ok(if cfg!(target_arch = "x86_64") {
Platform::LinuxAlpineX64
} else {
Platform::LinuxAlpineARM64
});
}
let mut errors: Vec<String> = vec![];
if let Err(e) = gnu_a {
errors.push(e);
} else if let Err(e) = gnu_b {
errors.push(e);
}
if let Err(e) = or_musl {
errors.push(e);
}
let bullets = errors
.iter()
.map(|e| format!(" - {}", e))
.collect::<Vec<String>>()
.join("\n");
Err(AnyError::from(SetupError(format!(
"This machine not meet VS Code Server's prerequisites, expected either...\n{}",
bullets,
))))
}
}
#[allow(dead_code)]
async fn check_musl_interpreter() -> Result<(), String> {
const MUSL_PATH: &str = if cfg!(target_platform = "aarch64") {
"/lib/ld-musl-aarch64.so.1"
} else {
"/lib/ld-musl-x86_64.so.1"
};
if fs::metadata(MUSL_PATH).await.is_err() {
return Err(format!(
"find {}, which is required to run the VS Code Server in musl environments",
MUSL_PATH
));
}
Ok(())
}
#[allow(dead_code)]
async fn check_glibc_version() -> Result<(), String> {
let ldd_version = capture_command("ldd", ["--version"])
.await
.ok()
.and_then(|o| extract_ldd_version(&o.stdout));
if let Some(v) = ldd_version {
return if v.gte(&MIN_LDD_VERSION) {
Ok(())
} else {
Err(format!(
"find GLIBC >= 2.17 (but found {} instead) for GNU environments",
v
))
};
}
Ok(())
}
#[allow(dead_code)]
async fn check_glibcxx_version() -> Result<(), String> {
let mut libstdc_path: Option<String> = None;
const DEFAULT_LIB_PATH: &str = "/usr/lib64/libstdc++.so.6";
const LDCONFIG_PATH: &str = "/sbin/ldconfig";
if fs::metadata(DEFAULT_LIB_PATH).await.is_ok() {
libstdc_path = Some(DEFAULT_LIB_PATH.to_owned());
} else if fs::metadata(LDCONFIG_PATH).await.is_ok() {
libstdc_path = capture_command(LDCONFIG_PATH, ["-p"])
.await
.ok()
.and_then(|o| extract_libstd_from_ldconfig(&o.stdout));
}
match libstdc_path {
Some(path) => match fs::read(&path).await {
Ok(contents) => check_for_sufficient_glibcxx_versions(contents),
Err(e) => Err(format!(
"validate GLIBCXX version for GNU environments, but could not: {}",
e
)),
},
None => Err("find libstdc++.so or ldconfig for GNU environments".to_owned()),
}
}
#[allow(dead_code)]
fn check_for_sufficient_glibcxx_versions(contents: Vec<u8>) -> Result<(), String> {
let all_versions: Vec<SimpleSemver> = LIBSTD_CXX_VERSION_RE
.captures_iter(&contents)
.map(|m| SimpleSemver {
major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())),
minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())),
patch: m.get(3).map_or(0, |s| u32_from_bytes(s.as_bytes())),
})
.collect();
if !all_versions.iter().any(|v| MIN_CXX_VERSION.gte(v)) {
return Err(format!(
"find GLIBCXX >= 3.4.18 (but found {} instead) for GNU environments",
all_versions
.iter()
.map(String::from)
.collect::<Vec<String>>()
.join(", ")
));
}
Ok(())
}
fn extract_ldd_version(output: &[u8]) -> Option<SimpleSemver> {
LDD_VERSION_RE.captures(output).map(|m| SimpleSemver {
major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())),
minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())),
patch: 0,
})
}
fn extract_libstd_from_ldconfig(output: &[u8]) -> Option<String> {
String::from_utf8_lossy(output)
.lines()
.find_map(|l| LDCONFIG_STDC_RE.captures(l))
.and_then(|cap| cap.get(1))
.map(|cap| cap.as_str().to_owned())
}
fn u32_from_bytes(b: &[u8]) -> u32 {
String::from_utf8_lossy(b).parse::<u32>().unwrap_or(0)
}
#[derive(Debug, PartialEq)]
struct SimpleSemver {
major: u32,
minor: u32,
patch: u32,
}
impl From<&SimpleSemver> for String {
fn from(s: &SimpleSemver) -> Self {
format!("v{}.{}.{}", s.major, s.minor, s.patch)
}
}
impl std::fmt::Display for SimpleSemver {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", String::from(self))
}
}
#[allow(dead_code)]
impl SimpleSemver {
fn new(major: u32, minor: u32, patch: u32) -> SimpleSemver {
SimpleSemver {
major,
minor,
patch,
}
}
fn gte(&self, other: &SimpleSemver) -> bool {
match self.major.cmp(&other.major) {
Ordering::Greater => true,
Ordering::Less => false,
Ordering::Equal => match self.minor.cmp(&other.minor) {
Ordering::Greater => true,
Ordering::Less => false,
Ordering::Equal => self.patch >= other.patch,
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_libstd_from_ldconfig() {
let actual = "
libstoken.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstoken.so.1
libstemmer.so.0d (libc6,x86-64) => /lib/x86_64-linux-gnu/libstemmer.so.0d
libstdc++.so.6 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstdc++.so.6
libstartup-notification-1.so.0 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstartup-notification-1.so.0
libssl3.so (libc6,x86-64) => /lib/x86_64-linux-gnu/libssl3.so
".to_owned().into_bytes();
assert_eq!(
extract_libstd_from_ldconfig(&actual),
Some("/lib/x86_64-linux-gnu/libstdc++.so.6".to_owned()),
);
assert_eq!(
extract_libstd_from_ldconfig(&"nothing here!".to_owned().into_bytes()),
None,
);
}
#[test]
fn test_gte() {
assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 2, 3)));
assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(0, 10, 10)));
assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 1, 10)));
assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 2, 10)));
assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 3, 1)));
assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(2, 2, 1)));
}
#[test]
fn check_for_sufficient_glibcxx_versions() {
let actual = "ldd (Ubuntu GLIBC 2.31-0ubuntu9.7) 2.31
Copyright (C) 2020 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
Written by Roland McGrath and Ulrich Drepper."
.to_owned()
.into_bytes();
assert_eq!(
extract_ldd_version(&actual),
Some(SimpleSemver::new(2, 31, 0)),
);
}
}

89
src/cli/src/util/sync.rs Normal file
View file

@ -0,0 +1,89 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use tokio::sync::watch::{
self,
error::{RecvError, SendError},
};
#[derive(Clone)]
pub struct Barrier<T>(watch::Receiver<Option<T>>)
where
T: Copy;
impl<T> Barrier<T>
where
T: Copy,
{
/// Waits for the barrier to be closed, returning a value if one was sent.
pub async fn wait(&mut self) -> Result<T, RecvError> {
loop {
if let Err(e) = self.0.changed().await {
return Err(e);
}
if let Some(v) = *(self.0.borrow()) {
return Ok(v);
}
}
}
}
pub struct BarrierOpener<T>(watch::Sender<Option<T>>);
impl<T> BarrierOpener<T> {
/// Closes the barrier.
pub fn open(self, value: T) -> Result<(), SendError<Option<T>>> {
self.0.send(Some(value))
}
}
/// The Barrier is something that can be opened once from one side,
/// and is thereafter permanently closed. It can contain a value.
pub fn new_barrier<T>() -> (Barrier<T>, BarrierOpener<T>)
where
T: Copy,
{
let (closed_tx, closed_rx) = watch::channel(None);
(Barrier(closed_rx), BarrierOpener(closed_tx))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_barrier_close_after_spawn() {
let (mut barrier, opener) = new_barrier::<u32>();
let (tx, rx) = tokio::sync::oneshot::channel::<u32>();
tokio::spawn(async move {
tx.send(barrier.wait().await.unwrap()).unwrap();
});
opener.open(42).unwrap();
assert!(rx.await.unwrap() == 42);
}
#[tokio::test]
async fn test_barrier_close_before_spawn() {
let (barrier, opener) = new_barrier::<u32>();
let (tx1, rx1) = tokio::sync::oneshot::channel::<u32>();
let (tx2, rx2) = tokio::sync::oneshot::channel::<u32>();
opener.open(42).unwrap();
let mut b1 = barrier.clone();
tokio::spawn(async move {
tx1.send(b1.wait().await.unwrap()).unwrap();
});
let mut b2 = barrier.clone();
tokio::spawn(async move {
tx2.send(b2.wait().await.unwrap()).unwrap();
});
assert!(rx1.await.unwrap() == 42);
assert!(rx2.await.unwrap() == 42);
}
}

52
src/cli/src/util/tar.rs Normal file
View file

@ -0,0 +1,52 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use crate::util::errors::{wrap, WrappedError};
use flate2::read::GzDecoder;
use std::fs::File;
use std::path::{Path, PathBuf};
use tar::Archive;
use super::io::ReportCopyProgress;
pub fn decompress_tarball<T>(
path: &Path,
parent_path: &Path,
mut reporter: T,
) -> Result<(), WrappedError>
where
T: ReportCopyProgress,
{
let tar_gz = File::open(path).map_err(|e| {
wrap(
Box::new(e),
format!("error opening file {}", path.display()),
)
})?;
let tar = GzDecoder::new(tar_gz);
let mut archive = Archive::new(tar);
let results = archive
.entries()
.map_err(|e| wrap(e, format!("error opening archive {}", path.display())))?
.filter_map(|e| e.ok())
.map(|mut entry| {
let entry_path = entry
.path()
.map_err(|e| wrap(e, "error reading entry path"))?;
let path = parent_path.join(entry_path.iter().skip(1).collect::<PathBuf>());
entry
.unpack(&path)
.map_err(|e| wrap(e, format!("error unpacking {}", path.display())))?;
Ok(path)
})
.collect::<Result<Vec<PathBuf>, WrappedError>>()?;
// Tarballs don't have a way to get the number of entries ahead of time
reporter.report_progress(results.len() as u64, results.len() as u64);
Ok(())
}

155
src/cli/src/util/zipper.rs Normal file
View file

@ -0,0 +1,155 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
use super::errors::{wrap, WrappedError};
use super::io::ReportCopyProgress;
use std::fs::{self, File};
use std::io;
use std::path::Path;
use std::path::PathBuf;
use zip::read::ZipFile;
use zip::{self, ZipArchive};
// Borrowed and modified from https://github.com/zip-rs/zip/blob/master/examples/extract.rs
/// Returns whether all files in the archive start with the same path segment.
/// If so, it's an indication we should skip that segment when extracting.
fn should_skip_first_segment(archive: &mut ZipArchive<File>) -> bool {
let first_name = {
let file = archive
.by_index_raw(0)
.expect("expected not to have an empty archive");
let path = file
.enclosed_name()
.expect("expected to have path")
.iter()
.next()
.expect("expected to have non-empty name");
path.to_owned()
};
for i in 1..archive.len() {
if let Ok(file) = archive.by_index_raw(i) {
if let Some(name) = file.enclosed_name() {
if name.iter().next() != Some(&first_name) {
return false;
}
}
}
}
true
}
pub fn unzip_file<T>(path: &Path, parent_path: &Path, mut reporter: T) -> Result<(), WrappedError>
where
T: ReportCopyProgress,
{
let file = fs::File::open(path)
.map_err(|e| wrap(e, format!("unable to open file {}", path.display())))?;
let mut archive = zip::ZipArchive::new(file)
.map_err(|e| wrap(e, format!("failed to open zip archive {}", path.display())))?;
let skip_segments_no = if should_skip_first_segment(&mut archive) {
1
} else {
0
};
for i in 0..archive.len() {
reporter.report_progress(i as u64, archive.len() as u64);
let mut file = archive
.by_index(i)
.map_err(|e| wrap(e, format!("could not open zip entry {}", i)))?;
let outpath: PathBuf = match file.enclosed_name() {
Some(path) => {
let mut full_path = PathBuf::from(parent_path);
full_path.push(PathBuf::from_iter(path.iter().skip(skip_segments_no)));
full_path
}
None => continue,
};
if file.is_dir() || file.name().ends_with('/') {
fs::create_dir_all(&outpath)
.map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?;
apply_permissions(&file, &outpath)?;
continue;
}
if let Some(p) = outpath.parent() {
fs::create_dir_all(&p)
.map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?;
}
#[cfg(unix)]
{
use libc::S_IFLNK;
use std::io::Read;
use std::os::unix::ffi::OsStringExt;
if matches!(file.unix_mode(), Some(mode) if mode & (S_IFLNK as u32) == (S_IFLNK as u32))
{
let mut link_to = Vec::new();
file.read_to_end(&mut link_to).map_err(|e| {
wrap(
e,
format!("could not read symlink linkpath {}", outpath.display()),
)
})?;
let link_path = PathBuf::from(std::ffi::OsString::from_vec(link_to));
std::os::unix::fs::symlink(link_path, &outpath).map_err(|e| {
wrap(e, format!("could not create symlink {}", outpath.display()))
})?;
continue;
}
}
let mut outfile = fs::File::create(&outpath).map_err(|e| {
wrap(
e,
format!(
"unable to open file to write {} (from {:?})",
outpath.display(),
file.enclosed_name().map(|p| p.to_string_lossy()),
),
)
})?;
io::copy(&mut file, &mut outfile)
.map_err(|e| wrap(e, format!("error copying file {}", outpath.display())))?;
apply_permissions(&file, &outpath)?;
}
reporter.report_progress(archive.len() as u64, archive.len() as u64);
Ok(())
}
#[cfg(unix)]
fn apply_permissions(file: &ZipFile, outpath: &Path) -> Result<(), WrappedError> {
use std::os::unix::fs::PermissionsExt;
if let Some(mode) = file.unix_mode() {
fs::set_permissions(&outpath, fs::Permissions::from_mode(mode)).map_err(|e| {
wrap(
e,
format!("error setting permissions on {}", outpath.display()),
)
})?;
}
Ok(())
}
#[cfg(windows)]
fn apply_permissions(_file: &ZipFile, _outpath: &Path) -> Result<(), WrappedError> {
Ok(())
}