fix(ext/flash): graceful server startup/shutdown (#16383)

Fixes https://github.com/denoland/deno/issues/16267

Co-authored-by: Yusuke Tanaka <yusuktan@maguro.dev>
This commit is contained in:
Divy Srivastava 2022-11-11 05:41:52 -08:00 committed by GitHub
parent 5be8c96ae8
commit ff92febb38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 224 additions and 116 deletions

View file

@ -1167,3 +1167,68 @@ fn run_watch_dynamic_imports() {
check_alive_then_kill(child);
}
// https://github.com/denoland/deno/issues/16267
#[test]
fn run_watch_flash() {
let filename = "watch_flash.js";
let t = TempDir::new();
let file_to_watch = t.path().join(filename);
write(
&file_to_watch,
r#"
console.log("Starting flash server...");
Deno.serve({
onListen() {
console.error("First server is listening");
},
handler: () => {},
port: 4601,
});
"#,
)
.unwrap();
let mut child = util::deno_cmd()
.current_dir(t.path())
.arg("run")
.arg("--watch")
.arg("--unstable")
.arg("--allow-net")
.arg("-L")
.arg("debug")
.arg(&file_to_watch)
.env("NO_COLOR", "1")
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.unwrap();
let (mut stdout_lines, mut stderr_lines) = child_lines(&mut child);
wait_contains("Starting flash server...", &mut stdout_lines);
wait_for(
|m| m.contains("Watching paths") && m.contains(filename),
&mut stderr_lines,
);
write(
&file_to_watch,
r#"
console.log("Restarting flash server...");
Deno.serve({
onListen() {
console.error("Second server is listening");
},
handler: () => {},
port: 4601,
});
"#,
)
.unwrap();
wait_contains("File change detected! Restarting!", &mut stderr_lines);
wait_contains("Restarting flash server...", &mut stdout_lines);
wait_contains("Second server is listening", &mut stderr_lines);
check_alive_then_kill(child);
}

View file

@ -70,6 +70,7 @@ Deno.test(async function httpServerRejectsOnAddrInUse() {
onError: createOnErrorCb(ac),
});
await listeningPromise;
assertRejects(
() =>
Deno.serve({

View file

@ -188,8 +188,8 @@
return str;
}
function prepareFastCalls() {
return core.ops.op_flash_make_request();
function prepareFastCalls(serverId) {
return core.ops.op_flash_make_request(serverId);
}
function hostnameForDisplay(hostname) {
@ -482,15 +482,11 @@
const serverId = opFn(listenOpts);
const serverPromise = core.opAsync("op_flash_drive_server", serverId);
PromisePrototypeCatch(
PromisePrototypeThen(
core.opAsync("op_flash_wait_for_listening", serverId),
(port) => {
onListen({ hostname: listenOpts.hostname, port });
},
),
() => {},
const listenPromise = PromisePrototypeThen(
core.opAsync("op_flash_wait_for_listening", serverId),
(port) => {
onListen({ hostname: listenOpts.hostname, port });
},
);
const finishedPromise = PromisePrototypeCatch(serverPromise, () => {});
@ -506,7 +502,7 @@
return;
}
server.closed = true;
await core.opAsync("op_flash_close_server", serverId);
core.ops.op_flash_close_server(serverId);
await server.finished;
},
async serve() {
@ -618,7 +614,7 @@
signal?.addEventListener("abort", () => {
clearInterval(dateInterval);
PromisePrototypeThen(server.close(), () => {}, () => {});
server.close();
}, {
once: true,
});
@ -633,7 +629,7 @@
);
}
const fastOp = prepareFastCalls();
const fastOp = prepareFastCalls(serverId);
let nextRequestSync = () => fastOp.nextRequest();
let getMethodSync = (token) => fastOp.getMethod(token);
let respondFast = (token, response, shutdown) =>
@ -653,8 +649,8 @@
}
await SafePromiseAll([
listenPromise,
PromisePrototypeCatch(server.serve(), console.error),
serverPromise,
]);
};
}

View file

@ -35,6 +35,7 @@ use mio::Events;
use mio::Interest;
use mio::Poll;
use mio::Token;
use mio::Waker;
use serde::Deserialize;
use serde::Serialize;
use socket2::Socket;
@ -55,7 +56,6 @@ use std::rc::Rc;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Context;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
@ -76,15 +76,24 @@ pub struct FlashContext {
pub servers: HashMap<u32, ServerContext>,
}
impl Drop for FlashContext {
fn drop(&mut self) {
// Signal each server instance to shutdown.
for (_, server) in self.servers.drain() {
let _ = server.waker.wake();
}
}
}
pub struct ServerContext {
_addr: SocketAddr,
tx: mpsc::Sender<Request>,
rx: mpsc::Receiver<Request>,
rx: Option<mpsc::Receiver<Request>>,
requests: HashMap<u32, Request>,
next_token: u32,
listening_rx: Option<mpsc::Receiver<u16>>,
close_tx: mpsc::Sender<()>,
listening_rx: Option<mpsc::Receiver<Result<u16, std::io::Error>>>,
cancel_handle: Rc<CancelHandle>,
waker: Arc<Waker>,
}
#[derive(Debug, Eq, PartialEq)]
@ -102,7 +111,10 @@ fn op_flash_respond(
shutdown: bool,
) -> u32 {
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
let ctx = match flash_ctx.servers.get_mut(&server_id) {
Some(ctx) => ctx,
None => return 0,
};
flash_respond(ctx, token, shutdown, &response)
}
@ -120,7 +132,10 @@ async fn op_flash_respond_async(
let sock = {
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
let ctx = match flash_ctx.servers.get_mut(&server_id) {
Some(ctx) => ctx,
None => return Ok(()),
};
match shutdown {
true => {
@ -384,15 +399,30 @@ fn op_flash_method(state: &mut OpState, server_id: u32, token: u32) -> u32 {
}
#[op]
async fn op_flash_close_server(state: Rc<RefCell<OpState>>, server_id: u32) {
let close_tx = {
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
ctx.cancel_handle.cancel();
ctx.close_tx.clone()
fn op_flash_drive_server(
state: &mut OpState,
server_id: u32,
) -> Result<impl Future<Output = Result<(), AnyError>> + 'static, AnyError> {
let join_handle = {
let flash_ctx = state.borrow_mut::<FlashContext>();
flash_ctx
.join_handles
.remove(&server_id)
.ok_or_else(|| type_error("server not found"))?
};
let _ = close_tx.send(()).await;
Ok(async move {
join_handle
.await
.map_err(|_| type_error("server join error"))??;
Ok(())
})
}
#[op]
fn op_flash_close_server(state: &mut OpState, server_id: u32) {
let flash_ctx = state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.remove(&server_id).unwrap();
let _ = ctx.waker.wake();
}
#[op]
@ -419,7 +449,7 @@ fn op_flash_path(
fn next_request_sync(ctx: &mut ServerContext) -> u32 {
let offset = ctx.next_token;
while let Ok(token) = ctx.rx.try_recv() {
while let Ok(token) = ctx.rx.as_mut().unwrap().try_recv() {
ctx.requests.insert(ctx.next_token, token);
ctx.next_token += 1;
}
@ -482,6 +512,7 @@ unsafe fn op_flash_get_method_fast(
fn op_flash_make_request<'scope>(
scope: &mut v8::HandleScope<'scope>,
state: &mut OpState,
server_id: u32,
) -> serde_v8::Value<'scope> {
let object_template = v8::ObjectTemplate::new(scope);
assert!(object_template
@ -489,7 +520,7 @@ fn op_flash_make_request<'scope>(
let obj = object_template.new_instance(scope).unwrap();
let ctx = {
let flash_ctx = state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&0).unwrap();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
ctx as *mut ServerContext
};
obj.set_aligned_pointer_in_internal_field(V8_WRAPPER_OBJECT_INDEX, ctx as _);
@ -705,7 +736,10 @@ async fn op_flash_read_body(
{
let op_state = &mut state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
flash_ctx.servers.get_mut(&server_id).unwrap() as *mut ServerContext
match flash_ctx.servers.get_mut(&server_id) {
Some(ctx) => ctx as *mut ServerContext,
None => return 0,
}
}
.as_mut()
.unwrap()
@ -807,41 +841,40 @@ pub struct ListenOpts {
reuseport: bool,
}
const SERVER_TOKEN: Token = Token(0);
// Token reserved for the thread close signal.
const WAKER_TOKEN: Token = Token(1);
#[allow(clippy::too_many_arguments)]
fn run_server(
tx: mpsc::Sender<Request>,
listening_tx: mpsc::Sender<u16>,
mut close_rx: mpsc::Receiver<()>,
listening_tx: mpsc::Sender<Result<u16, std::io::Error>>,
addr: SocketAddr,
maybe_cert: Option<String>,
maybe_key: Option<String>,
reuseport: bool,
mut poll: Poll,
// We put a waker as an unused argument here as it needs to be alive both in
// the flash thread and in the main thread (otherwise the notification would
// not be caught by the event loop on Linux).
// See the comment in mio's example:
// https://docs.rs/mio/0.8.4/x86_64-unknown-linux-gnu/mio/struct.Waker.html#examples
_waker: Arc<Waker>,
) -> Result<(), AnyError> {
let domain = if addr.is_ipv4() {
socket2::Domain::IPV4
} else {
socket2::Domain::IPV6
let mut listener = match listen(addr, reuseport) {
Ok(listener) => listener,
Err(e) => {
listening_tx.blocking_send(Err(e)).unwrap();
return Err(generic_error(
"failed to start listening on the specified address",
));
}
};
let socket = Socket::new(domain, socket2::Type::STREAM, None)?;
#[cfg(not(windows))]
socket.set_reuse_address(true)?;
if reuseport {
#[cfg(target_os = "linux")]
socket.set_reuse_port(true)?;
}
let socket_addr = socket2::SockAddr::from(addr);
socket.bind(&socket_addr)?;
socket.listen(128)?;
socket.set_nonblocking(true)?;
let std_listener: std::net::TcpListener = socket.into();
let mut listener = TcpListener::from_std(std_listener);
let mut poll = Poll::new()?;
let token = Token(0);
// Register server.
poll
.registry()
.register(&mut listener, token, Interest::READABLE)
.register(&mut listener, SERVER_TOKEN, Interest::READABLE)
.unwrap();
let tls_context: Option<Arc<rustls::ServerConfig>> = {
@ -863,30 +896,24 @@ fn run_server(
};
listening_tx
.blocking_send(listener.local_addr().unwrap().port())
.blocking_send(Ok(listener.local_addr().unwrap().port()))
.unwrap();
let mut sockets = HashMap::with_capacity(1000);
let mut counter: usize = 1;
let mut counter: usize = 2;
let mut events = Events::with_capacity(1024);
'outer: loop {
let result = close_rx.try_recv();
if result.is_ok() {
break 'outer;
}
// FIXME(bartlomieju): how does Tokio handle it? I just put random 100ms
// timeout here to handle close signal.
match poll.poll(&mut events, Some(Duration::from_millis(100))) {
match poll.poll(&mut events, None) {
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => panic!("{}", e),
Ok(()) => (),
}
'events: for event in &events {
if close_rx.try_recv().is_ok() {
break 'outer;
}
let token = event.token();
match token {
Token(0) => loop {
WAKER_TOKEN => {
break 'outer;
}
SERVER_TOKEN => loop {
match listener.accept() {
Ok((mut socket, _)) => {
counter += 1;
@ -1149,6 +1176,33 @@ fn run_server(
Ok(())
}
#[inline]
fn listen(
addr: SocketAddr,
reuseport: bool,
) -> Result<TcpListener, std::io::Error> {
let domain = if addr.is_ipv4() {
socket2::Domain::IPV4
} else {
socket2::Domain::IPV6
};
let socket = Socket::new(domain, socket2::Type::STREAM, None)?;
#[cfg(not(windows))]
socket.set_reuse_address(true)?;
if reuseport {
#[cfg(target_os = "linux")]
socket.set_reuse_port(true)?;
}
let socket_addr = socket2::SockAddr::from(addr);
socket.bind(&socket_addr)?;
socket.listen(128)?;
socket.set_nonblocking(true)?;
let std_listener: std::net::TcpListener = socket.into();
Ok(TcpListener::from_std(std_listener))
}
fn make_addr_port_pair(hostname: &str, port: u16) -> (&str, u16) {
// Default to localhost if given just the port. Example: ":80"
if hostname.is_empty() {
@ -1186,17 +1240,19 @@ where
.next()
.ok_or_else(|| generic_error("No resolved address found"))?;
let (tx, rx) = mpsc::channel(100);
let (close_tx, close_rx) = mpsc::channel(1);
let (listening_tx, listening_rx) = mpsc::channel(1);
let poll = Poll::new()?;
let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN).unwrap());
let ctx = ServerContext {
_addr: addr,
tx,
rx,
rx: Some(rx),
requests: HashMap::with_capacity(1000),
next_token: 0,
close_tx,
listening_rx: Some(listening_rx),
cancel_handle: CancelHandle::new_rc(),
waker: waker.clone(),
};
let tx = ctx.tx.clone();
let maybe_cert = opts.cert;
@ -1206,11 +1262,12 @@ where
run_server(
tx,
listening_tx,
close_rx,
addr,
maybe_cert,
maybe_key,
reuseport,
poll,
waker,
)
});
let flash_ctx = state.borrow_mut::<FlashContext>();
@ -1245,45 +1302,26 @@ where
}
#[op]
fn op_flash_wait_for_listening(
state: &mut OpState,
async fn op_flash_wait_for_listening(
state: Rc<RefCell<OpState>>,
server_id: u32,
) -> Result<impl Future<Output = Result<u16, AnyError>> + 'static, AnyError> {
) -> Result<u16, AnyError> {
let mut listening_rx = {
let flash_ctx = state.borrow_mut::<FlashContext>();
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let server_ctx = flash_ctx
.servers
.get_mut(&server_id)
.ok_or_else(|| type_error("server not found"))?;
server_ctx.listening_rx.take().unwrap()
};
Ok(async move {
if let Some(port) = listening_rx.recv().await {
Ok(port)
} else {
Err(generic_error("This error will be discarded"))
}
})
}
#[op]
fn op_flash_drive_server(
state: &mut OpState,
server_id: u32,
) -> Result<impl Future<Output = Result<(), AnyError>> + 'static, AnyError> {
let join_handle = {
let flash_ctx = state.borrow_mut::<FlashContext>();
flash_ctx
.join_handles
.remove(&server_id)
.ok_or_else(|| type_error("server not found"))?
};
Ok(async move {
join_handle
.await
.map_err(|_| type_error("server join error"))??;
Ok(())
})
match listening_rx.recv().await {
Some(Ok(port)) => Ok(port),
Some(Err(e)) => Err(e.into()),
_ => Err(generic_error(
"unknown error occurred while waiting for listening",
)),
}
}
// Asychronous version of op_flash_next. This can be a bottleneck under
@ -1291,26 +1329,34 @@ fn op_flash_drive_server(
// requests i.e `op_flash_next() == 0`.
#[op]
async fn op_flash_next_async(
op_state: Rc<RefCell<OpState>>,
state: Rc<RefCell<OpState>>,
server_id: u32,
) -> u32 {
let ctx = {
let mut op_state = op_state.borrow_mut();
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
let cancel_handle = ctx.cancel_handle.clone();
let mut rx = ctx.rx.take().unwrap();
// We need to drop the borrow before await point.
drop(op_state);
if let Ok(Some(req)) = rx.recv().or_cancel(&cancel_handle).await {
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
ctx as *mut ServerContext
};
// SAFETY: we cannot hold op_state borrow across the await point. The JS caller
// is responsible for ensuring this is not called concurrently.
let ctx = unsafe { &mut *ctx };
let cancel_handle = &ctx.cancel_handle;
if let Ok(Some(req)) = ctx.rx.recv().or_cancel(cancel_handle).await {
ctx.requests.insert(ctx.next_token, req);
ctx.next_token += 1;
// Set the rx back.
ctx.rx = Some(rx);
return 1;
}
// Set the rx back.
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
if let Some(ctx) = flash_ctx.servers.get_mut(&server_id) {
ctx.rx = Some(rx);
}
0
}
@ -1478,11 +1524,11 @@ pub fn init<P: FlashPermissions + 'static>(unstable: bool) -> Extension {
op_flash_next_async::decl(),
op_flash_read_body::decl(),
op_flash_upgrade_websocket::decl(),
op_flash_drive_server::decl(),
op_flash_wait_for_listening::decl(),
op_flash_first_packet::decl(),
op_flash_has_body_stream::decl(),
op_flash_close_server::decl(),
op_flash_drive_server::decl(),
op_flash_make_request::decl(),
op_flash_write_resource::decl(),
])