diff --git a/Tests/LibCompress/TestLzma.cpp b/Tests/LibCompress/TestLzma.cpp index 58e9d5decb..67f2a0d9f2 100644 --- a/Tests/LibCompress/TestLzma.cpp +++ b/Tests/LibCompress/TestLzma.cpp @@ -75,6 +75,26 @@ TEST_CASE(compress_decompress_roundtrip_with_unknown_size) EXPECT_EQ(uncompressed, result.span()); } +TEST_CASE(compress_long_overflow_chain) +{ + // Encoding 0xFF followed by the end-of-stream marker results in a chain of bytes that doesn't fit into 64 bits, + // which breaks naive implementations of "hold back the byte until it no longer changes". + + Array const uncompressed { + 0xFF + }; + + auto stream = MUST(try_make()); + auto compressor = TRY_OR_FAIL(Compress::LzmaCompressor::create_container(MaybeOwned { *stream }, {})); + TRY_OR_FAIL(compressor->write_until_depleted(uncompressed)); + TRY_OR_FAIL(compressor->flush()); + + auto decompressor = TRY_OR_FAIL(Compress::LzmaDecompressor::create_from_container(MaybeOwned { *stream })); + auto result = TRY_OR_FAIL(decompressor->read_until_eof()); + + EXPECT_EQ(uncompressed, result.span()); +} + // The following tests are based on test files from the LZMA specification, which has been placed in the public domain. // LZMA Specification Draft (2015): https://www.7-zip.org/a/lzma-specification.7z diff --git a/Userland/Libraries/LibCompress/Lzma.cpp b/Userland/Libraries/LibCompress/Lzma.cpp index e1dc18c847..97a65e595a 100644 --- a/Userland/Libraries/LibCompress/Lzma.cpp +++ b/Userland/Libraries/LibCompress/Lzma.cpp @@ -249,33 +249,53 @@ ErrorOr LzmaDecompressor::normalize_range_decoder() return {}; } +ErrorOr LzmaCompressor::shift_range_encoder() +{ + if ((m_range_encoder_code >> 32) == 0x01) { + // If there is an overflow, we can finalize the chain we were previously building. + // This includes incrementing both the cached byte and all the 0xFF bytes that we generate. + VERIFY(m_range_encoder_cached_byte != 0xFF); + TRY(m_stream->write_value(m_range_encoder_cached_byte + 1)); + for (size_t i = 0; i < m_range_encoder_ff_chain_length; i++) + TRY(m_stream->write_value(0x00)); + m_range_encoder_ff_chain_length = 0; + m_range_encoder_cached_byte = (m_range_encoder_code >> 24); + } else if ((m_range_encoder_code >> 24) == 0xFF) { + // If the byte to flush is 0xFF, it can potentially propagate an overflow and needs to be added to the chain. + m_range_encoder_ff_chain_length++; + } else { + // If the byte to flush isn't 0xFF, any future overflows will not be propagated beyond this point, + // so we can be sure that the built chain doesn't change anymore. + TRY(m_stream->write_value(m_range_encoder_cached_byte)); + for (size_t i = 0; i < m_range_encoder_ff_chain_length; i++) + TRY(m_stream->write_value(0xFF)); + m_range_encoder_ff_chain_length = 0; + m_range_encoder_cached_byte = (m_range_encoder_code >> 24); + } + + // In all three cases we now recorded the highest byte in some way, so we can shift it away and shift in a null byte as the lowest byte. + m_range_encoder_range <<= 8; + m_range_encoder_code <<= 8; + + // Since we are working with a 64-bit code, we need to limit it to 32 bits artificially. + m_range_encoder_code &= 0xFFFFFFFF; + + return {}; +} + ErrorOr LzmaCompressor::normalize_range_encoder() { u64 const maximum_range_value = m_range_encoder_code + m_range_encoder_range; - // If we hit this, we have the potential to overflow into a byte that we already flushed. - VERIFY((maximum_range_value & ((1ull << m_range_encoder_code_used_bits) - 1)) == maximum_range_value); + // Logically, we should only ever build up an overflow that is smaller than or equal to 0x01. + VERIFY((maximum_range_value >> 32) <= 0x01); constexpr u32 minimum_range_value = 1 << 24; if (m_range_encoder_range >= minimum_range_value) return {}; - u64 const flipped_bits = maximum_range_value ^ m_range_encoder_code; - u64 const size_of_flipped_bits = count_required_bits(flipped_bits); - - // If we can flush a full byte without impacting future bits, do so. - while (m_range_encoder_code_used_bits - 8 >= size_of_flipped_bits) { - u8 const next_byte = (m_range_encoder_code >> (m_range_encoder_code_used_bits - 8)); - m_range_encoder_code -= static_cast(next_byte) << (m_range_encoder_code_used_bits - 8); - m_range_encoder_code_used_bits -= 8; - TRY(m_stream->write_value(next_byte)); - } - - // Now, shift in a fresh null byte from the bottom. - m_range_encoder_range <<= 8; - m_range_encoder_code <<= 8; - m_range_encoder_code_used_bits += 8; + TRY(shift_range_encoder()); VERIFY(m_range_encoder_range >= minimum_range_value); @@ -1212,10 +1232,6 @@ ErrorOr> LzmaCompressor::create_container(MaybeOwn auto header = TRY(LzmaHeader::from_compressor_options(options)); TRY(stream->write_value(header)); - // Note: The reference LZMA implementation has a starting null byte due to how their overflow reservoir is implemented and subsequently wrote it into the specification. - // Therefore, we just have to add it manually. - TRY(stream->write_value(0x00)); - auto compressor = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LzmaCompressor(move(stream), options, move(dictionary), move(literal_probabilities)))); return compressor; @@ -1276,13 +1292,18 @@ ErrorOr LzmaCompressor::flush() if (!m_options.uncompressed_size.has_value()) TRY(encode_normalized_simple_match(end_of_stream_marker, 0)); - while (m_range_encoder_code_used_bits > 0) { - VERIFY(m_range_encoder_code_used_bits >= 8); - u8 const next_byte = (m_range_encoder_code >> (m_range_encoder_code_used_bits - 8)); - m_range_encoder_code -= static_cast(next_byte) << (m_range_encoder_code_used_bits - 8); - m_range_encoder_code_used_bits -= 8; - TRY(m_stream->write_value(next_byte)); - } + // Shifting the range encoder using the normal operation handles any pending overflows. + TRY(shift_range_encoder()); + + // Now, the remaining bytes are the cached byte, the chain of 0xFF, and the upper 3 bytes of the current `code`. + // Incrementing the values does not have to be considered as no overflows are pending. The fourth byte is the + // null byte that we just shifted in, which should not be flushed as it would be extraneous junk data. + TRY(m_stream->write_value(m_range_encoder_cached_byte)); + for (size_t i = 0; i < m_range_encoder_ff_chain_length; i++) + TRY(m_stream->write_value(0xFF)); + TRY(m_stream->write_value(m_range_encoder_code >> 24)); + TRY(m_stream->write_value(m_range_encoder_code >> 16)); + TRY(m_stream->write_value(m_range_encoder_code >> 8)); m_has_flushed_data = true; return {}; diff --git a/Userland/Libraries/LibCompress/Lzma.h b/Userland/Libraries/LibCompress/Lzma.h index 2672d43c4d..e9b811b179 100644 --- a/Userland/Libraries/LibCompress/Lzma.h +++ b/Userland/Libraries/LibCompress/Lzma.h @@ -225,6 +225,7 @@ public: private: LzmaCompressor(MaybeOwned, LzmaCompressorOptions, MaybeOwned, FixedArray literal_probabilities); + ErrorOr shift_range_encoder(); ErrorOr normalize_range_encoder(); ErrorOr encode_direct_bit(u8 value); ErrorOr encode_bit_with_probability(Probability&, u8 value); @@ -253,7 +254,12 @@ private: // Range encoder state. u32 m_range_encoder_range { 0xFFFFFFFF }; u64 m_range_encoder_code { 0 }; - size_t m_range_encoder_code_used_bits { 32 }; + + // Since the range is only 32-bits, we can overflow at most +1 into the next byte beyond the usual 32-bit code. + // Therefore, it is sufficient to store the highest byte (which may still change due to that +1 overflow) and + // the length of the chain of 0xFF bytes that may end up propagating that change. + u8 m_range_encoder_cached_byte { 0x00 }; + size_t m_range_encoder_ff_chain_length { 0 }; }; }