From 359fcf348fb04ac7d66992138e71b0818a95056d Mon Sep 17 00:00:00 2001 From: asynts Date: Wed, 2 Sep 2020 17:09:00 +0200 Subject: [PATCH] AK: Add Buffered which wraps a stream, adding input buffering. --- AK/Buffered.h | 187 ++++++++++++++++++++++++++++++++++++++++++++++++++ AK/Stream.h | 16 ++--- 2 files changed, 195 insertions(+), 8 deletions(-) create mode 100644 AK/Buffered.h diff --git a/AK/Buffered.h b/AK/Buffered.h new file mode 100644 index 0000000000..fc74fe1af0 --- /dev/null +++ b/AK/Buffered.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020, the SerenityOS developers. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#pragma once + +#include + +namespace AK { + +// FIXME: Implement Buffered for DuplexStream. + +template +class Buffered; + +template +class Buffered::value>::Type> final : public InputStream { +public: + template + explicit Buffered(Parameters&&... parameters) + : m_stream(forward(parameters)...) + { + } + + bool has_recoverable_error() const override { return m_stream.has_recoverable_error(); } + bool has_fatal_error() const override { return m_stream.has_fatal_error(); } + bool has_any_error() const override { return m_stream.has_any_error(); } + + bool handle_recoverable_error() override { return m_stream.handle_recoverable_error(); } + bool handle_fatal_error() override { return m_stream.handle_fatal_error(); } + bool handle_any_error() override { return m_stream.handle_any_error(); } + + void set_recoverable_error() const override { return m_stream.set_recoverable_error(); } + void set_fatal_error() const override { return m_stream.set_fatal_error(); } + + size_t read(Bytes bytes) override + { + auto nread = buffer().trim(m_buffer_remaining).copy_trimmed_to(bytes); + + m_buffer_remaining -= nread; + buffer().slice(nread, m_buffer_remaining).copy_to(buffer()); + + if (nread < bytes.size()) { + m_buffer_remaining = m_stream.read(buffer()); + + if (m_buffer_remaining == 0) + return nread; + + nread += read(bytes.slice(nread)); + } + + return nread; + } + + virtual bool read_or_error(Bytes bytes) override + { + if (read(bytes) < bytes.size()) { + set_fatal_error(); + return false; + } + + return true; + } + + virtual bool eof() const + { + if (m_buffer_remaining > 0) + return false; + + m_buffer_remaining = m_stream.read(buffer()); + + return m_buffer_remaining == 0; + } + + virtual bool discard_or_error(size_t count) override + { + size_t ndiscarded = 0; + while (ndiscarded < count) { + u8 dummy[Size]; + + if (!read_or_error({ dummy, min(Size, count - ndiscarded) })) + return false; + + ndiscarded += min(Size, count - ndiscarded); + } + + return true; + } + +private: + Bytes buffer() const { return { m_buffer, Size }; } + + mutable StreamType m_stream; + mutable u8 m_buffer[Size]; + mutable size_t m_buffer_remaining { 0 }; +}; + +template +class Buffered::value>::Type> final : public OutputStream { +public: + template + explicit Buffered(Parameters&&... parameters) + : m_stream(forward(parameters)...) + { + } + + ~Buffered() + { + flush(); + } + + bool has_recoverable_error() const override { return m_stream.has_recoverable_error(); } + bool has_fatal_error() const override { return m_stream.has_fatal_error(); } + bool has_any_error() const override { return m_stream.has_any_error(); } + + bool handle_recoverable_error() override { return m_stream.handle_recoverable_error(); } + bool handle_fatal_error() override { return m_stream.handle_fatal_error(); } + bool handle_any_error() override { return m_stream.handle_any_error(); } + + void set_recoverable_error() const override { return m_stream.set_recoverable_error(); } + void set_fatal_error() const override { return m_stream.set_fatal_error(); } + + size_t write(ReadonlyBytes bytes) override + { + if (has_any_error()) + return 0; + + auto nwritten = bytes.copy_trimmed_to(buffer().slice(m_buffered)); + m_buffered += nwritten; + + if (m_buffered == Size) { + flush(); + + if (bytes.size() - nwritten >= Size) + nwritten += m_stream.write_or_error(bytes); + + nwritten += write(bytes.slice(nwritten)); + } + + return nwritten; + } + + bool write_or_error(ReadonlyBytes bytes) override + { + write(bytes); + return true; + } + + void flush() + { + m_stream.write_or_error({ m_buffer, m_buffered }); + m_buffered = 0; + } + +private: + Bytes buffer() { return { m_buffer, Size }; } + + StreamType m_stream; + u8 m_buffer[Size]; + size_t m_buffered { 0 }; +}; + +} + +using AK::Buffered; diff --git a/AK/Stream.h b/AK/Stream.h index c8fd053822..4c3dc389c8 100644 --- a/AK/Stream.h +++ b/AK/Stream.h @@ -39,17 +39,17 @@ class Stream { public: virtual ~Stream() { ASSERT(!has_any_error()); } - bool has_recoverable_error() const { return m_recoverable_error; } - bool has_fatal_error() const { return m_fatal_error; } - bool has_any_error() const { return has_recoverable_error() || has_fatal_error(); } + virtual bool has_recoverable_error() const { return m_recoverable_error; } + virtual bool has_fatal_error() const { return m_fatal_error; } + virtual bool has_any_error() const { return has_recoverable_error() || has_fatal_error(); } - bool handle_recoverable_error() + virtual bool handle_recoverable_error() { ASSERT(!has_fatal_error()); return exchange(m_recoverable_error, false); } - bool handle_fatal_error() { return exchange(m_fatal_error, false); } - bool handle_any_error() + virtual bool handle_fatal_error() { return exchange(m_fatal_error, false); } + virtual bool handle_any_error() { if (has_any_error()) { m_recoverable_error = false; @@ -61,8 +61,8 @@ public: return false; } - void set_recoverable_error() const { m_recoverable_error = true; } - void set_fatal_error() const { m_fatal_error = true; } + virtual void set_recoverable_error() const { m_recoverable_error = true; } + virtual void set_fatal_error() const { m_fatal_error = true; } private: mutable bool m_recoverable_error { false };