Kernel: Add SocketHandle helper class that wraps locked sockets.

This allows us to have a comfy IPv4Socket::from_tcp_port() API that returns
a socket that's locked and safe to access. No need to worry about locking
at the client site.
This commit is contained in:
Andreas Kling 2019-03-14 09:19:24 +01:00
parent 3d5296a901
commit 54e7df0586
4 changed files with 106 additions and 22 deletions

View file

@ -27,6 +27,34 @@ Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_tcp_port()
return *s_map;
}
IPv4SocketHandle IPv4Socket::from_tcp_port(word port)
{
RetainPtr<IPv4Socket> socket;
{
LOCKER(sockets_by_tcp_port().lock());
auto it = sockets_by_tcp_port().resource().find(port);
if (it == sockets_by_tcp_port().resource().end())
return { };
socket = (*it).value;
ASSERT(socket);
}
return { move(socket) };
}
IPv4SocketHandle IPv4Socket::from_udp_port(word port)
{
RetainPtr<IPv4Socket> socket;
{
LOCKER(sockets_by_udp_port().lock());
auto it = sockets_by_udp_port().resource().find(port);
if (it == sockets_by_udp_port().resource().end())
return { };
socket = (*it).value;
ASSERT(socket);
}
return { move(socket) };
}
Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
{
static Lockable<HashTable<IPv4Socket*>>* s_table;
@ -217,8 +245,12 @@ NetworkOrdered<word> IPv4Socket::compute_tcp_checksum(const IPv4Address& source,
if (checksum > 0xffff)
checksum = (checksum >> 16) + (checksum & 0xffff);
}
if (payload_size & 1)
ASSERT_NOT_REACHED();
if (payload_size & 1) {
word expanded_byte = ((const byte*)packet.payload())[payload_size - 1];
checksum += expanded_byte;
if (checksum > 0xffff)
checksum = (checksum >> 16) + (checksum & 0xffff);
}
return ~(checksum & 0xffff);
}

View file

@ -7,6 +7,7 @@
#include <AK/Lock.h>
#include <AK/SinglyLinkedList.h>
class IPv4SocketHandle;
class NetworkAdapter;
class TCPPacket;
@ -28,6 +29,9 @@ public:
static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_udp_port();
static Lockable<HashMap<word, IPv4Socket*>>& sockets_by_tcp_port();
static IPv4SocketHandle from_tcp_port(word);
static IPv4SocketHandle from_udp_port(word);
virtual KResult bind(const sockaddr*, socklen_t) override;
virtual KResult connect(const sockaddr*, socklen_t) override;
virtual bool get_address(sockaddr*, socklen_t*) override;
@ -79,3 +83,26 @@ private:
bool m_can_read { false };
};
class IPv4SocketHandle : public SocketHandle {
public:
IPv4SocketHandle() { }
IPv4SocketHandle(RetainPtr<IPv4Socket>&& socket)
: SocketHandle(move(socket))
{
}
IPv4SocketHandle(IPv4SocketHandle&& other)
: SocketHandle(move(other))
{
}
IPv4SocketHandle(const IPv4SocketHandle&) = delete;
IPv4SocketHandle& operator=(const IPv4SocketHandle&) = delete;
IPv4Socket* operator->() { return &socket(); }
const IPv4Socket* operator->() const { return &socket(); }
IPv4Socket& socket() { return static_cast<IPv4Socket&>(SocketHandle::socket()); }
const IPv4Socket& socket() const { return static_cast<const IPv4Socket&>(SocketHandle::socket()); }
};

View file

@ -234,17 +234,12 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size)
);
#endif
RetainPtr<IPv4Socket> socket;
{
LOCKER(IPv4Socket::sockets_by_udp_port().lock());
auto it = IPv4Socket::sockets_by_udp_port().resource().find(udp_packet.destination_port());
if (it == IPv4Socket::sockets_by_udp_port().resource().end())
return;
ASSERT((*it).value);
socket = *(*it).value;
auto socket = IPv4Socket::from_udp_port(udp_packet.destination_port());
if (!socket) {
kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port());
return;
}
LOCKER(socket->lock());
ASSERT(socket->type() == SOCK_DGRAM);
ASSERT(socket->source_port() == udp_packet.destination_port());
socket->did_receive(ByteBuffer::copy((const byte*)&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
@ -280,19 +275,12 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
);
#endif
RetainPtr<IPv4Socket> socket;
{
LOCKER(IPv4Socket::sockets_by_tcp_port().lock());
auto it = IPv4Socket::sockets_by_tcp_port().resource().find(tcp_packet.destination_port());
if (it == IPv4Socket::sockets_by_tcp_port().resource().end()) {
kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
return;
}
ASSERT((*it).value);
socket = *(*it).value;
auto socket = IPv4Socket::from_tcp_port(tcp_packet.destination_port());
if (!socket) {
kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
return;
}
LOCKER(socket->lock());
ASSERT(socket->type() == SOCK_STREAM);
ASSERT(socket->source_port() == tcp_packet.destination_port());

View file

@ -76,3 +76,40 @@ private:
Vector<RetainPtr<Socket>> m_pending;
Vector<RetainPtr<Socket>> m_clients;
};
class SocketHandle {
public:
SocketHandle() { }
SocketHandle(RetainPtr<Socket>&& socket)
: m_socket(move(socket))
{
if (m_socket)
m_socket->lock().lock();
}
SocketHandle(SocketHandle&& other)
: m_socket(move(other.m_socket))
{
}
~SocketHandle()
{
if (m_socket)
m_socket->lock().unlock();
}
SocketHandle(const SocketHandle&) = delete;
SocketHandle& operator=(const SocketHandle&) = delete;
operator bool() const { return m_socket; }
Socket* operator->() { return &socket(); }
const Socket* operator->() const { return &socket(); }
Socket& socket() { return *m_socket; }
const Socket& socket() const { return *m_socket; }
private:
RetainPtr<Socket> m_socket;
};