From c7f468d33b5d0814b56036639eb2a8226d4bfbbf Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 24 Jul 2024 13:20:06 -0700 Subject: [PATCH] fix(ext/fetch): use correct ALPN to proxies (#24696) Sending ALPN to a proxy, and then when tunneling, requires better juggling of TLS configs. This improves the choice of TLS config in the proxy connector, based on what reqwest does. It also includes some `ext/fetch/tests.rs` that check the different combinations. Fixes #24632 Fixes #24691 --- Cargo.lock | 1 + ext/fetch/Cargo.toml | 1 + ext/fetch/lib.rs | 20 +++-- ext/fetch/proxy.rs | 61 +++++++------- ext/fetch/tests.rs | 188 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 235 insertions(+), 36 deletions(-) create mode 100644 ext/fetch/tests.rs diff --git a/Cargo.lock b/Cargo.lock index db31a66380..341e1cdba7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1482,6 +1482,7 @@ dependencies = [ "hyper-rustls", "hyper-util", "ipnet", + "rustls-webpki", "serde", "serde_json", "tokio", diff --git a/ext/fetch/Cargo.toml b/ext/fetch/Cargo.toml index bb8aed367f..c28d88b240 100644 --- a/ext/fetch/Cargo.toml +++ b/ext/fetch/Cargo.toml @@ -27,6 +27,7 @@ hyper.workspace = true hyper-rustls.workspace = true hyper-util.workspace = true ipnet.workspace = true +rustls-webpki.workspace = true serde.workspace = true serde_json.workspace = true tokio.workspace = true diff --git a/ext/fetch/lib.rs b/ext/fetch/lib.rs index 1372329c4f..9912ff3072 100644 --- a/ext/fetch/lib.rs +++ b/ext/fetch/lib.rs @@ -2,6 +2,8 @@ mod fs_fetch_handler; mod proxy; +#[cfg(test)] +mod tests; use std::borrow::Cow; use std::cell::RefCell; @@ -62,7 +64,6 @@ use http::Method; use http::Uri; use http_body_util::BodyExt; use hyper::body::Frame; -use hyper_rustls::HttpsConnector; use hyper_util::client::legacy::connect::HttpConnector; use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioIo; @@ -975,6 +976,10 @@ pub fn create_http_client( deno_tls::SocketUse::Http, )?; + // Proxy TLS should not send ALPN + tls_config.alpn_protocols.clear(); + let proxy_tls_config = Arc::from(tls_config.clone()); + let mut alpn_protocols = vec![]; if options.http2 { alpn_protocols.push("h2".into()); @@ -987,7 +992,6 @@ pub fn create_http_client( let mut http_connector = HttpConnector::new(); http_connector.enforce_http(false); - let connector = HttpsConnector::from((http_connector, tls_config.clone())); let user_agent = user_agent .parse::() @@ -1008,9 +1012,13 @@ pub fn create_http_client( proxies.prepend(intercept); } let proxies = Arc::new(proxies); - let mut connector = - proxy::ProxyConnector::new(proxies.clone(), connector, tls_config); - connector.user_agent(user_agent.clone()); + let connector = proxy::ProxyConnector { + http: http_connector, + proxies: proxies.clone(), + tls: tls_config, + tls_proxy: proxy_tls_config, + user_agent: Some(user_agent.clone()), + }; if let Some(pool_max_idle_per_host) = options.pool_max_idle_per_host { builder.pool_max_idle_per_host(pool_max_idle_per_host); @@ -1059,7 +1067,7 @@ pub struct Client { user_agent: HeaderValue, } -type Connector = proxy::ProxyConnector>; +type Connector = proxy::ProxyConnector; // clippy is wrong here #[allow(clippy::declare_interior_mutable_const)] diff --git a/ext/fetch/proxy.rs b/ext/fetch/proxy.rs index db187c3f68..c8e54d5ec6 100644 --- a/ext/fetch/proxy.rs +++ b/ext/fetch/proxy.rs @@ -17,6 +17,8 @@ use deno_tls::rustls::ClientConfig as TlsConfig; use http::header::HeaderValue; use http::uri::Scheme; use http::Uri; +use hyper_rustls::HttpsConnector; +use hyper_rustls::MaybeHttpsStream; use hyper_util::client::legacy::connect::Connected; use hyper_util::client::legacy::connect::Connection; use hyper_util::rt::TokioIo; @@ -29,10 +31,14 @@ use tower_service::Service; #[derive(Debug, Clone)] pub(crate) struct ProxyConnector { - connector: C, - proxies: Arc, - tls: Arc, - user_agent: Option, + pub(crate) http: C, + pub(crate) proxies: Arc, + /// TLS config when destination is not a proxy + pub(crate) tls: Arc, + /// TLS config when destination is a proxy + /// Notably, does not include ALPN + pub(crate) tls_proxy: Arc, + pub(crate) user_agent: Option, } #[derive(Debug)] @@ -361,23 +367,6 @@ impl DomainMatcher { } impl ProxyConnector { - pub(crate) fn new( - proxies: Arc, - connector: C, - tls: Arc, - ) -> Self { - ProxyConnector { - connector, - proxies, - tls, - user_agent: None, - } - } - - pub(crate) fn user_agent(&mut self, val: HeaderValue) { - self.user_agent = Some(val); - } - fn intercept(&self, dst: &Uri) -> Option<&Intercept> { self.proxies.intercept(dst) } @@ -438,12 +427,13 @@ pub enum Proxied { impl Service for ProxyConnector where - C: Service, - C::Response: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + C: Service + Clone, + C::Response: + hyper::rt::Read + hyper::rt::Write + Connection + Unpin + Send + 'static, C::Future: Send + 'static, C::Error: Into + 'static, { - type Response = Proxied; + type Response = Proxied>; type Error = BoxError; type Future = BoxFuture>; @@ -451,7 +441,7 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll> { - self.connector.poll_ready(cx).map_err(Into::into) + self.http.poll_ready(cx).map_err(Into::into) } fn call(&mut self, orig_dst: Uri) -> Self::Future { @@ -467,10 +457,12 @@ where dst: proxy_dst, auth, } => { - let connecting = self.connector.call(proxy_dst); + let mut connector = + HttpsConnector::from((self.http.clone(), self.tls_proxy.clone())); + let connecting = connector.call(proxy_dst); let tls = TlsConnector::from(self.tls.clone()); Box::pin(async move { - let mut io = connecting.await.map_err(Into::into)?; + let mut io = connecting.await.map_err(Into::::into)?; if is_https { tunnel(&mut io, &orig_dst, user_agent, auth).await?; @@ -529,9 +521,11 @@ where } }; } + + let mut connector = + HttpsConnector::from((self.http.clone(), self.tls.clone())); Box::pin( - self - .connector + connector .call(orig_dst) .map_ok(Proxied::PassThrough) .map_err(Into::into), @@ -721,7 +715,14 @@ where match self { Proxied::PassThrough(ref p) => p.connected(), Proxied::HttpForward(ref p) => p.connected().proxy(true), - Proxied::HttpTunneled(ref p) => p.inner().get_ref().0.connected(), + Proxied::HttpTunneled(ref p) => { + let tunneled_tls = p.inner().get_ref(); + if tunneled_tls.1.alpn_protocol() == Some(b"h2") { + tunneled_tls.0.connected().negotiated_h2() + } else { + tunneled_tls.0.connected() + } + } Proxied::Socks(ref p) => p.connected(), Proxied::SocksTls(ref p) => p.inner().get_ref().0.connected(), } diff --git a/ext/fetch/tests.rs b/ext/fetch/tests.rs new file mode 100644 index 0000000000..c99a08d34c --- /dev/null +++ b/ext/fetch/tests.rs @@ -0,0 +1,188 @@ +// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. + +use std::net::SocketAddr; +use std::sync::Arc; + +use bytes::Bytes; +use http_body_util::BodyExt; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; + +use super::create_http_client; +use super::CreateHttpClientOptions; + +static EXAMPLE_CRT: &[u8] = include_bytes!("../tls/testdata/example1_cert.der"); +static EXAMPLE_KEY: &[u8] = + include_bytes!("../tls/testdata/example1_prikey.der"); + +#[tokio::test] +async fn test_https_proxy_http11() { + let src_addr = create_https_server(false).await; + let prx_addr = create_http_proxy(src_addr).await; + run_test_client(prx_addr, src_addr, false, http::Version::HTTP_11).await; +} + +#[tokio::test] +async fn test_https_proxy_h2() { + let src_addr = create_https_server(true).await; + let prx_addr = create_http_proxy(src_addr).await; + run_test_client(prx_addr, src_addr, false, http::Version::HTTP_2).await; +} + +#[tokio::test] +async fn test_https_proxy_https_h2() { + let src_addr = create_https_server(true).await; + let prx_addr = create_https_proxy(src_addr).await; + run_test_client(prx_addr, src_addr, true, http::Version::HTTP_2).await; +} + +async fn run_test_client( + prx_addr: SocketAddr, + src_addr: SocketAddr, + https: bool, + ver: http::Version, +) { + let client = create_http_client( + "fetch/test", + CreateHttpClientOptions { + root_cert_store: None, + ca_certs: vec![], + proxy: Some(deno_tls::Proxy { + url: format!("http{}://{}", if https { "s" } else { "" }, prx_addr), + basic_auth: None, + }), + unsafely_ignore_certificate_errors: Some(vec![]), + client_cert_chain_and_key: None, + pool_max_idle_per_host: None, + pool_idle_timeout: None, + http1: true, + http2: true, + }, + ) + .unwrap(); + + let req = http::Request::builder() + .uri(format!("https://{}/foo", src_addr)) + .body( + http_body_util::Empty::new() + .map_err(|err| match err {}) + .boxed(), + ) + .unwrap(); + let resp = client.send(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + assert_eq!(resp.version(), ver); + let hello = resp.collect().await.unwrap().to_bytes(); + assert_eq!(hello, "hello from server"); +} + +async fn create_https_server(allow_h2: bool) -> SocketAddr { + let mut tls_config = deno_tls::rustls::server::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert( + vec![EXAMPLE_CRT.into()], + webpki::types::PrivateKeyDer::try_from(EXAMPLE_KEY).unwrap(), + ) + .unwrap(); + if allow_h2 { + tls_config.alpn_protocols.push("h2".into()); + } + tls_config.alpn_protocols.push("http/1.1".into()); + let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::from(tls_config)); + let src_tcp = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let src_addr = src_tcp.local_addr().unwrap(); + + tokio::spawn(async move { + while let Ok((sock, _)) = src_tcp.accept().await { + let conn = tls_acceptor.accept(sock).await.unwrap(); + if conn.get_ref().1.alpn_protocol() == Some(b"h2") { + let fut = hyper::server::conn::http2::Builder::new( + hyper_util::rt::TokioExecutor::new(), + ) + .serve_connection( + hyper_util::rt::TokioIo::new(conn), + hyper::service::service_fn(|_req| async { + Ok::<_, std::convert::Infallible>(http::Response::new( + http_body_util::Full::::new("hello from server".into()), + )) + }), + ); + tokio::spawn(fut); + } else { + let fut = hyper::server::conn::http1::Builder::new().serve_connection( + hyper_util::rt::TokioIo::new(conn), + hyper::service::service_fn(|_req| async { + Ok::<_, std::convert::Infallible>(http::Response::new( + http_body_util::Full::::new("hello from server".into()), + )) + }), + ); + tokio::spawn(fut); + } + } + }); + + src_addr +} + +async fn create_http_proxy(src_addr: SocketAddr) -> SocketAddr { + let prx_tcp = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let prx_addr = prx_tcp.local_addr().unwrap(); + + tokio::spawn(async move { + while let Ok((mut sock, _)) = prx_tcp.accept().await { + let fut = async move { + let mut buf = [0u8; 4096]; + let _n = sock.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..7], b"CONNECT"); + let mut dst_tcp = + tokio::net::TcpStream::connect(src_addr).await.unwrap(); + sock.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); + tokio::io::copy_bidirectional(&mut sock, &mut dst_tcp) + .await + .unwrap(); + }; + tokio::spawn(fut); + } + }); + + prx_addr +} + +async fn create_https_proxy(src_addr: SocketAddr) -> SocketAddr { + let mut tls_config = deno_tls::rustls::server::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert( + vec![EXAMPLE_CRT.into()], + webpki::types::PrivateKeyDer::try_from(EXAMPLE_KEY).unwrap(), + ) + .unwrap(); + // Set ALPN, to check our proxy connector. But we shouldn't receive anything. + tls_config.alpn_protocols.push("h2".into()); + tls_config.alpn_protocols.push("http/1.1".into()); + let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::from(tls_config)); + let prx_tcp = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let prx_addr = prx_tcp.local_addr().unwrap(); + + tokio::spawn(async move { + while let Ok((sock, _)) = prx_tcp.accept().await { + let mut sock = tls_acceptor.accept(sock).await.unwrap(); + assert_eq!(sock.get_ref().1.alpn_protocol(), None); + + let fut = async move { + let mut buf = [0u8; 4096]; + let _n = sock.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..7], b"CONNECT"); + let mut dst_tcp = + tokio::net::TcpStream::connect(src_addr).await.unwrap(); + sock.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); + tokio::io::copy_bidirectional(&mut sock, &mut dst_tcp) + .await + .unwrap(); + }; + tokio::spawn(fut); + } + }); + + prx_addr +}