deno/ext/websocket/stream.rs
2024-01-01 19:58:21 +00:00

179 lines
5.5 KiB
Rust

// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use bytes::Buf;
use bytes::Bytes;
use deno_net::raw::NetworkStream;
use h2::RecvStream;
use h2::SendStream;
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::io::ErrorKind;
use std::pin::Pin;
use std::task::ready;
use std::task::Poll;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
// TODO(bartlomieju): remove this
/// The concrete transport underneath a [`WebSocketStream`].
pub(crate) enum WsStreamKind {
  /// A `deno_net` raw network stream.
  Network(NetworkStream),
  /// A connection obtained through `hyper`'s upgrade mechanism, adapted to
  /// tokio's I/O traits via `TokioIo`.
  Upgraded(TokioIo<Upgraded>),
  /// An HTTP/2 stream, held as its send half and its receive half.
  H2(SendStream<Bytes>, RecvStream),
}
/// Byte stream used by the WebSocket implementation.
///
/// Reads are served from `pre` first (while it holds any bytes) and only then
/// from `stream`; for the H2 transport, `pre` also holds the unread tail of
/// a partially consumed DATA frame.
pub(crate) struct WebSocketStream {
  /// The underlying transport.
  stream: WsStreamKind,
  /// Bytes handed to readers before anything is read from `stream`.
  pre: Option<Bytes>,
}

impl WebSocketStream {
  /// Wraps `stream`; any bytes in `buffer` are returned to readers before
  /// data from the transport itself.
  pub fn new(stream: WsStreamKind, buffer: Option<Bytes>) -> Self {
    Self { pre: buffer, stream }
  }
}
impl AsyncRead for WebSocketStream {
// From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = std::cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
prefix.advance(copy_len);
// Put back what's left
if !prefix.is_empty() {
self.pre = Some(prefix);
}
return Poll::Ready(Ok(()));
}
}
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf),
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_read(cx, buf),
WsStreamKind::H2(_, recv) => {
let data = ready!(recv.poll_data(cx));
let Some(data) = data else {
// EOF
return Poll::Ready(Ok(()));
};
let mut data = data.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
})?;
recv.flow_control().release_capacity(data.len()).unwrap();
// This looks like the prefix code above -- can we share this?
let copy_len = std::cmp::min(data.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&data[..copy_len]);
data.advance(copy_len);
// Put back what's left
if !data.is_empty() {
self.pre = Some(data);
}
Poll::Ready(Ok(()))
}
}
}
}
impl AsyncWrite for WebSocketStream {
  /// Writes `buf` to the underlying transport.
  ///
  /// For the H2 transport, at most as many bytes as the stream currently has
  /// send capacity for are sent as a single DATA frame. A stream that can no
  /// longer send yields `BrokenPipe`; h2 errors are wrapped with
  /// `ErrorKind::Other`.
  fn poll_write(
    mut self: Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
    buf: &[u8],
  ) -> std::task::Poll<Result<usize, std::io::Error>> {
    match &mut self.stream {
      WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf),
      WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf),
      WsStreamKind::H2(send, _) => {
        // Zero-length write succeeds
        if buf.is_empty() {
          return Poll::Ready(Ok(0));
        }
        send.reserve_capacity(buf.len());
        match ready!(send.poll_capacity(cx)) {
          Some(Ok(_)) => {}
          Some(Err(e)) => {
            return Poll::Ready(Err(std::io::Error::new(ErrorKind::Other, e)));
          }
          // The stream can no longer send (closed or reset). Previously this
          // fell through with zero capacity and hit `assert!(size > 0)`.
          None => {
            return Poll::Ready(Err(std::io::Error::from(ErrorKind::BrokenPipe)));
          }
        }
        // We'll try to send whatever we have capacity for
        let size = std::cmp::min(buf.len(), send.capacity());
        if size == 0 {
          // Woken without any usable capacity: reschedule ourselves and try
          // again rather than panicking.
          cx.waker().wake_by_ref();
          return Poll::Pending;
        }
        let data = Bytes::copy_from_slice(&buf[..size]);
        Poll::Ready(
          send
            .send_data(data, false)
            .map(|()| size)
            .map_err(|e| std::io::Error::new(ErrorKind::Other, e)),
        )
      }
    }
  }

  /// Flushes the underlying transport; a no-op for H2, which has no
  /// per-stream flush here.
  fn poll_flush(
    mut self: Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
  ) -> std::task::Poll<Result<(), std::io::Error>> {
    match &mut self.stream {
      WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx),
      WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx),
      WsStreamKind::H2(..) => Poll::Ready(Ok(())),
    }
  }

  /// Shuts down the write side of the transport.
  ///
  /// For H2 this sends an empty DATA frame with END_STREAM set, the HTTP/2
  /// equivalent of half-closing a socket. (An empty frame without END_STREAM,
  /// as sent previously, does not signal anything to the peer.)
  fn poll_shutdown(
    mut self: Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
  ) -> std::task::Poll<Result<(), std::io::Error>> {
    match &mut self.stream {
      WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx),
      WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx),
      WsStreamKind::H2(send, _) => {
        let res = send
          .send_data(Bytes::new(), true)
          .map_err(|e| std::io::Error::new(ErrorKind::Other, e));
        Poll::Ready(res)
      }
    }
  }

  /// Whether the transport supports efficient vectored writes; H2 does not
  /// (each write becomes one DATA frame).
  fn is_write_vectored(&self) -> bool {
    match &self.stream {
      WsStreamKind::Network(stream) => stream.is_write_vectored(),
      WsStreamKind::Upgraded(stream) => stream.is_write_vectored(),
      WsStreamKind::H2(..) => false,
    }
  }

  /// Vectored write. The H2 transport uses the same fallback as tokio's
  /// default implementation: write the first non-empty slice (or an empty
  /// one if there is none) through `poll_write`, instead of panicking.
  fn poll_write_vectored(
    mut self: Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
    bufs: &[std::io::IoSlice<'_>],
  ) -> std::task::Poll<Result<usize, std::io::Error>> {
    match &mut self.stream {
      WsStreamKind::Network(stream) => {
        Pin::new(stream).poll_write_vectored(cx, bufs)
      }
      WsStreamKind::Upgraded(stream) => {
        Pin::new(stream).poll_write_vectored(cx, bufs)
      }
      WsStreamKind::H2(..) => {
        let buf = bufs
          .iter()
          .find(|slice| !slice.is_empty())
          .map_or(&[][..], |slice| &**slice);
        self.poll_write(cx, buf)
      }
    }
  }
}