deno/ext/websocket/stream.rs

115 lines
3.3 KiB
Rust

// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
use bytes::Buf;
use bytes::Bytes;
use deno_net::raw::NetworkStream;
use hyper::upgrade::Upgraded;
use std::pin::Pin;
use std::task::Poll;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
// TODO(bartlomieju): remove this
/// The concrete transport that a `WebSocketStream` forwards its I/O to.
pub(crate) enum WsStreamKind {
/// A `hyper::upgrade::Upgraded` connection.
Upgraded(Upgraded),
/// A `deno_net::raw::NetworkStream`.
Network(NetworkStream),
}
/// A WebSocket transport paired with an optional buffer of bytes that reads
/// hand out before any data is read from the transport itself.
pub(crate) struct WebSocketStream {
// The underlying transport that reads and writes are forwarded to.
stream: WsStreamKind,
// Bytes served by `poll_read` ahead of `stream`; `None` when nothing is
// buffered (never supplied, or already fully consumed).
pre: Option<Bytes>,
}
impl WebSocketStream {
pub fn new(stream: WsStreamKind, buffer: Option<Bytes>) -> Self {
Self {
stream,
pre: buffer,
}
}
}
impl AsyncRead for WebSocketStream {
    // From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur
    /// Reads into `buf`, serving any bytes held in `self.pre` before touching
    /// the underlying stream.
    ///
    /// While buffered bytes remain, a call copies as many of them as fit into
    /// `buf`, keeps the leftover for the next call, and returns
    /// `Poll::Ready(Ok(()))` without polling the stream. Once the buffer is
    /// absent or empty, the read is forwarded to the wrapped stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `take()` leaves `None` behind, so an empty buffer is simply dropped.
        let buffered = self.pre.take().filter(|bytes| !bytes.is_empty());

        if let Some(mut pending) = buffered {
            let n = pending.len().min(buf.remaining());
            buf.put_slice(&pending[..n]);
            pending.advance(n);
            // Keep whatever did not fit so the next call can hand it out.
            if !pending.is_empty() {
                self.pre = Some(pending);
            }
            return Poll::Ready(Ok(()));
        }

        match &mut self.stream {
            WsStreamKind::Upgraded(inner) => Pin::new(inner).poll_read(cx, buf),
            WsStreamKind::Network(inner) => Pin::new(inner).poll_read(cx, buf),
        }
    }
}
impl AsyncWrite for WebSocketStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf),
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx),
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx),
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
fn is_write_vectored(&self) -> bool {
match &self.stream {
WsStreamKind::Network(stream) => stream.is_write_vectored(),
WsStreamKind::Upgraded(stream) => stream.is_write_vectored(),
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => {
Pin::new(stream).poll_write_vectored(cx, bufs)
}
WsStreamKind::Upgraded(stream) => {
Pin::new(stream).poll_write_vectored(cx, bufs)
}
}
}
}