LibCrypto+LibTLS: Avoid unaligned reads and writes

This adds an `AK::ByteReader` to help with that so we don't duplicate
the logic all over the place.
No more `*(const u16*)` and `*(const u32*)` for anyone.
This should help a little with #7060.
This commit is contained in:
Ali Mohammad Pur 2021-05-14 09:22:56 +04:30 committed by Linus Groh
parent bfd4c7a16c
commit df515e1d85
6 changed files with 88 additions and 17 deletions

69
AK/ByteReader.h Normal file
View file

@ -0,0 +1,69 @@
/*
* Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include <AK/Types.h>
namespace AK {
struct ByteReader {
static void store(u8* address, u16 value)
{
union {
u16 _16;
u8 _8[2];
} const v { ._16 = value };
__builtin_memcpy(address, v._8, 2);
}
static void store(u8* address, u32 value)
{
union {
u32 _32;
u8 _8[4];
} const v { ._32 = value };
__builtin_memcpy(address, v._8, 4);
}
static void load(const u8* address, u16& value)
{
union {
u16 _16;
u8 _8[2];
} v { ._16 = 0 };
__builtin_memcpy(&v._8, address, 2);
value = v._16;
}
static void load(const u8* address, u32& value)
{
union {
u32 _32;
u8 _8[4];
} v { ._32 = 0 };
__builtin_memcpy(&v._8, address, 4);
value = v._32;
}
static u16 load16(const u8* address)
{
u16 value;
load(address, value);
return value;
}
static u32 load32(const u8* address)
{
u32 value;
load(address, value);
return value;
}
};
}
using AK::ByteReader;

View file

@ -4,6 +4,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/ByteReader.h>
#include <AK/Debug.h>
#include <AK/MemoryStream.h>
#include <AK/Types.h>
@ -15,14 +16,13 @@ namespace {
static u32 to_u32(const u8* b)
{
return AK::convert_between_host_and_big_endian(*(const u32*)b);
return AK::convert_between_host_and_big_endian(ByteReader::load32(b));
}
static void to_u8s(u8* b, const u32* w)
{
for (auto i = 0; i < 4; ++i) {
auto& e = *((u32*)(b + i * 4));
e = AK::convert_between_host_and_big_endian(w[i]);
ByteReader::store(b + i * 4, AK::convert_between_host_and_big_endian(w[i]));
}
}

View file

@ -53,7 +53,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
dbgln("not enough data for version");
return (i8)Error::NeedMoreData;
}
auto version = (Version)AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
auto version = static_cast<Version>(AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res))));
res += 2;
if (!supports_version(version))
@ -84,7 +84,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
dbgln("not enough data for cipher suite listing");
return (i8)Error::NeedMoreData;
}
auto cipher = (CipherSuite)AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
auto cipher = static_cast<CipherSuite>(AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res))));
res += 2;
if (!supports_cipher(cipher)) {
m_context.cipher = CipherSuite::Invalid;
@ -113,14 +113,14 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
// Presence of extensions is determined by availability of bytes after compression_method
if (buffer.size() - res >= 2) {
auto extensions_bytes_total = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res += 2));
auto extensions_bytes_total = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res += 2)));
dbgln_if(TLS_DEBUG, "Extensions bytes total: {}", extensions_bytes_total);
}
while (buffer.size() - res >= 4) {
auto extension_type = (HandshakeExtension)AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
auto extension_type = (HandshakeExtension)AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res)));
res += 2;
u16 extension_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
u16 extension_length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res)));
res += 2;
dbgln_if(TLS_DEBUG, "Extension {} with length {}", (u16)extension_type, extension_length);
@ -134,14 +134,14 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
// ServerNameList total size
if (buffer.size() - res < 2)
return (i8)Error::NeedMoreData;
auto sni_name_list_bytes = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res += 2));
auto sni_name_list_bytes = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res += 2)));
dbgln_if(TLS_DEBUG, "SNI: expecting ServerNameList of {} bytes", sni_name_list_bytes);
// Exactly one ServerName should be present
if (buffer.size() - res < 3)
return (i8)Error::NeedMoreData;
auto sni_name_type = (NameType)(*(const u8*)buffer.offset_pointer(res++));
auto sni_name_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res += 2));
auto sni_name_length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res += 2)));
if (sni_name_type != NameType::HostName)
return (i8)Error::NotUnderstood;
@ -158,7 +158,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
}
} else if (extension_type == HandshakeExtension::ApplicationLayerProtocolNegotiation && m_context.alpn.size()) {
if (buffer.size() - res > 2) {
auto alpn_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
auto alpn_length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res)));
if (alpn_length && alpn_length <= extension_length - 2) {
const u8* alpn = buffer.offset_pointer(res + 2);
size_t alpn_position = 0;

View file

@ -34,7 +34,7 @@ void TLSv12::write_packet(ByteBuffer& packet)
void TLSv12::update_packet(ByteBuffer& packet)
{
u32 header_size = 5;
*(u16*)packet.offset_pointer(3) = AK::convert_between_host_and_network_endian((u16)(packet.size() - header_size));
ByteReader::store(packet.offset_pointer(3), AK::convert_between_host_and_network_endian((u16)(packet.size() - header_size)));
if (packet[0] != (u8)MessageType::ChangeCipher) {
if (packet[0] == (u8)MessageType::Handshake && packet.size() > header_size) {
@ -159,7 +159,7 @@ void TLSv12::update_packet(ByteBuffer& packet)
// store the correct ciphertext length into the packet
u16 ct_length = (u16)ct.size() - header_size;
*(u16*)ct.offset_pointer(header_size - 2) = AK::convert_between_host_and_network_endian(ct_length);
ByteReader::store(ct.offset_pointer(header_size - 2), AK::convert_between_host_and_network_endian(ct_length));
// replace the packet with the ciphertext
packet = ct;
@ -222,13 +222,14 @@ ssize_t TLSv12::handle_message(ReadonlyBytes buffer)
// FIXME: Read the version and verify it
if constexpr (TLS_DEBUG) {
auto version = (Version) * (const u16*)buffer.offset_pointer(buffer_position);
auto version = ByteReader::load16(buffer.offset_pointer(buffer_position));
dbgln("type={}, version={}", (u8)type, (u16)version);
}
buffer_position += 2;
auto length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(buffer_position));
auto length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(buffer_position)));
dbgln_if(TLS_DEBUG, "record length: {} at offset: {}", length, buffer_position);
buffer_position += 2;

View file

@ -7,6 +7,7 @@
#pragma once
#include <AK/ByteBuffer.h>
#include <AK/ByteReader.h>
#include <AK/Endian.h>
#include <AK/Types.h>
@ -38,7 +39,7 @@ public:
m_packet_data = ByteBuffer::create_uninitialized(size_hint + 16);
m_current_length = 5;
m_packet_data[0] = (u8)type;
*(u16*)m_packet_data.offset_pointer(1) = AK::convert_between_host_and_network_endian((u16)version);
ByteReader::store(m_packet_data.offset_pointer(1), AK::convert_between_host_and_network_endian((u16)version));
}
inline void append(u16 value)

View file

@ -601,7 +601,7 @@ void TLSv12::consume(ReadonlyBytes record)
dbgln_if(TLS_DEBUG, "message buffer length {}", buffer_length);
while (buffer_length >= 5) {
auto length = AK::convert_between_host_and_network_endian(*(u16*)m_context.message_buffer.offset_pointer(index + size_offset)) + header_size;
auto length = AK::convert_between_host_and_network_endian(ByteReader::load16(m_context.message_buffer.offset_pointer(index + size_offset))) + header_size;
if (length > buffer_length) {
dbgln_if(TLS_DEBUG, "Need more data: {} > {}", length, buffer_length);
break;