From 4e3b59a4bb6931d79e56599fc39150628ca27e08 Mon Sep 17 00:00:00 2001 From: Tim Ledbetter Date: Thu, 2 Nov 2023 20:04:11 +0000 Subject: [PATCH] LibDNS: Prefer spans over raw pointers when parsing DNS packets This means we don't have to keep track of the pointer and size separately. --- Userland/Libraries/LibDNS/Name.cpp | 12 +++++----- Userland/Libraries/LibDNS/Name.h | 2 +- Userland/Libraries/LibDNS/Packet.cpp | 22 +++++++++---------- Userland/Libraries/LibDNS/Packet.h | 2 +- Userland/Services/LookupServer/DNSServer.cpp | 2 +- .../Services/LookupServer/LookupServer.cpp | 4 ++-- .../Services/LookupServer/MulticastDNS.cpp | 4 ++-- 7 files changed, 23 insertions(+), 25 deletions(-) diff --git a/Userland/Libraries/LibDNS/Name.cpp b/Userland/Libraries/LibDNS/Name.cpp index 4226202531..7e7341baa1 100644 --- a/Userland/Libraries/LibDNS/Name.cpp +++ b/Userland/Libraries/LibDNS/Name.cpp @@ -21,14 +21,14 @@ Name::Name(DeprecatedString const& name) m_name = name; } -Name Name::parse(u8 const* data, size_t& offset, size_t max_offset, size_t recursion_level) +Name Name::parse(ReadonlyBytes data, size_t& offset, size_t recursion_level) { if (recursion_level > 4) return {}; StringBuilder builder; while (true) { - if (offset >= max_offset) + if (offset >= data.size()) return {}; u8 b = data[offset++]; if (b == '\0') { @@ -36,17 +36,17 @@ Name Name::parse(u8 const* data, size_t& offset, size_t max_offset, size_t recur return builder.to_deprecated_string(); } else if ((b & 0xc0) == 0xc0) { // The two bytes tell us the offset when to continue from. - if (offset >= max_offset) + if (offset >= data.size()) return {}; size_t dummy = (b & 0x3f) << 8 | data[offset++]; - auto rest_of_name = parse(data, dummy, max_offset, recursion_level + 1); + auto rest_of_name = parse(data, dummy, recursion_level + 1); builder.append(rest_of_name.as_string()); return builder.to_deprecated_string(); } else { // This is the length of a part. - if (offset + b >= max_offset) + if (offset + b >= data.size()) return {}; - builder.append((char const*)&data[offset], (size_t)b); + builder.append({ data.offset_pointer(offset), b }); builder.append('.'); offset += b; } diff --git a/Userland/Libraries/LibDNS/Name.h b/Userland/Libraries/LibDNS/Name.h index cf84c8df8e..3140b01d2b 100644 --- a/Userland/Libraries/LibDNS/Name.h +++ b/Userland/Libraries/LibDNS/Name.h @@ -17,7 +17,7 @@ public: Name() = default; Name(DeprecatedString const&); - static Name parse(u8 const* data, size_t& offset, size_t max_offset, size_t recursion_level = 0); + static Name parse(ReadonlyBytes data, size_t& offset, size_t recursion_level = 0); size_t serialized_size() const; DeprecatedString const& as_string() const { return m_name; } diff --git a/Userland/Libraries/LibDNS/Packet.cpp b/Userland/Libraries/LibDNS/Packet.cpp index abe5e25cb1..8f43ffb2ba 100644 --- a/Userland/Libraries/LibDNS/Packet.cpp +++ b/Userland/Libraries/LibDNS/Packet.cpp @@ -97,14 +97,14 @@ private: static_assert(sizeof(DNSRecordWithoutName) == 10); -Optional Packet::from_raw_packet(u8 const* raw_data, size_t raw_size) +Optional Packet::from_raw_packet(ReadonlyBytes bytes) { - if (raw_size < sizeof(PacketHeader)) { - dbgln("DNS response not large enough ({} out of {}) to be a DNS packet.", raw_size, sizeof(PacketHeader)); + if (bytes.size() < sizeof(PacketHeader)) { + dbgln("DNS response not large enough ({} out of {}) to be a DNS packet.", bytes.size(), sizeof(PacketHeader)); return {}; } - auto& header = *(PacketHeader const*)(raw_data); + auto const& header = *bit_cast(bytes.data()); dbgln_if(LOOKUPSERVER_DEBUG, "Got packet (ID: {})", header.id()); dbgln_if(LOOKUPSERVER_DEBUG, " Question count: {}", header.question_count()); dbgln_if(LOOKUPSERVER_DEBUG, " Answer count: {}", header.answer_count()); @@ -123,12 +123,12 @@ Optional Packet::from_raw_packet(u8 const* raw_data, size_t raw_size) size_t offset = sizeof(PacketHeader); for (u16 i = 0; i < header.question_count(); i++) { - auto name = Name::parse(raw_data, offset, raw_size); + auto name = Name::parse(bytes, offset); struct RawDNSAnswerQuestion { NetworkOrdered record_type; NetworkOrdered class_code; }; - auto& record_and_class = *(RawDNSAnswerQuestion const*)&raw_data[offset]; + auto const& record_and_class = *bit_cast(bytes.offset_pointer(offset)); u16 class_code = record_and_class.class_code & ~MDNS_WANTS_UNICAST_RESPONSE; bool mdns_wants_unicast_response = record_and_class.class_code & MDNS_WANTS_UNICAST_RESPONSE; packet.m_questions.empend(name, (RecordType)(u16)record_and_class.record_type, (RecordClass)class_code, mdns_wants_unicast_response); @@ -138,18 +138,16 @@ Optional Packet::from_raw_packet(u8 const* raw_data, size_t raw_size) } for (u16 i = 0; i < header.answer_count(); ++i) { - auto name = Name::parse(raw_data, offset, raw_size); - - auto& record = *(DNSRecordWithoutName const*)(&raw_data[offset]); + auto name = Name::parse(bytes, offset); + auto const& record = *bit_cast(bytes.offset_pointer(offset)); + offset += sizeof(DNSRecordWithoutName); DeprecatedString data; - offset += sizeof(DNSRecordWithoutName); - switch ((RecordType)record.type()) { case RecordType::PTR: { size_t dummy_offset = offset; - data = Name::parse(raw_data, dummy_offset, raw_size).as_string(); + data = Name::parse(bytes, dummy_offset).as_string(); break; } case RecordType::CNAME: diff --git a/Userland/Libraries/LibDNS/Packet.h b/Userland/Libraries/LibDNS/Packet.h index 7a1ab8262b..159fd2abc3 100644 --- a/Userland/Libraries/LibDNS/Packet.h +++ b/Userland/Libraries/LibDNS/Packet.h @@ -24,7 +24,7 @@ class Packet { public: Packet() = default; - static Optional from_raw_packet(u8 const*, size_t); + static Optional from_raw_packet(ReadonlyBytes bytes); ErrorOr to_byte_buffer() const; bool is_query() const { return !m_query_or_response; } diff --git a/Userland/Services/LookupServer/DNSServer.cpp b/Userland/Services/LookupServer/DNSServer.cpp index 2ba45cf00e..79e244d08e 100644 --- a/Userland/Services/LookupServer/DNSServer.cpp +++ b/Userland/Services/LookupServer/DNSServer.cpp @@ -29,7 +29,7 @@ ErrorOr DNSServer::handle_client() { sockaddr_in client_address; auto buffer = TRY(receive(1024, client_address)); - auto optional_request = Packet::from_raw_packet(buffer.data(), buffer.size()); + auto optional_request = Packet::from_raw_packet(buffer); if (!optional_request.has_value()) { dbgln("Got an invalid DNS packet"); return {}; diff --git a/Userland/Services/LookupServer/LookupServer.cpp b/Userland/Services/LookupServer/LookupServer.cpp index 53e807a628..eb96ec986e 100644 --- a/Userland/Services/LookupServer/LookupServer.cpp +++ b/Userland/Services/LookupServer/LookupServer.cpp @@ -263,13 +263,13 @@ ErrorOr> LookupServer::lookup(Name const& name, DeprecatedString TRY(udp_socket->write_until_depleted(buffer)); u8 response_buffer[4096]; - int nrecv = TRY(udp_socket->read_some({ response_buffer, sizeof(response_buffer) })).size(); + auto nrecv = TRY(udp_socket->read_some({ response_buffer, sizeof(response_buffer) })).size(); if (udp_socket->is_eof()) return Vector {}; did_get_response = true; - auto o_response = Packet::from_raw_packet(response_buffer, nrecv); + auto o_response = Packet::from_raw_packet({ response_buffer, nrecv }); if (!o_response.has_value()) return Vector {}; diff --git a/Userland/Services/LookupServer/MulticastDNS.cpp b/Userland/Services/LookupServer/MulticastDNS.cpp index 88a0d0bf7e..3c4b810993 100644 --- a/Userland/Services/LookupServer/MulticastDNS.cpp +++ b/Userland/Services/LookupServer/MulticastDNS.cpp @@ -56,7 +56,7 @@ MulticastDNS::MulticastDNS(Core::EventReceiver* parent) ErrorOr MulticastDNS::handle_packet() { auto buffer = TRY(receive(1024)); - auto optional_packet = Packet::from_raw_packet(buffer.data(), buffer.size()); + auto optional_packet = Packet::from_raw_packet(buffer); if (!optional_packet.has_value()) { dbgln("Got an invalid mDNS packet"); return {}; @@ -175,7 +175,7 @@ ErrorOr> MulticastDNS::lookup(Name const& name, RecordType record auto buffer = TRY(receive(1024)); if (buffer.is_empty()) return Vector {}; - auto optional_packet = Packet::from_raw_packet(buffer.data(), buffer.size()); + auto optional_packet = Packet::from_raw_packet(buffer); if (!optional_packet.has_value()) { dbgln("Got an invalid mDNS packet"); continue;