Kernel: Refactor TCP/IP stack

This has several significant changes to the networking stack.

* Significant refactoring of the TCP state machine. Right now it's
  probably more fragile than it used to be, but it handles quite a lot
  more of the handshake process.
* `TCPSocket` holds a `NetworkAdapter*`, assigned during `connect()` or
  `bind()`, whichever comes first.
* `listen()` is now virtual in `Socket` and intended to be implemented
  in its child classes.
* `listen()` no longer works without `bind()` - this is a bit of a
  regression, but listening sockets didn't work at all before, so it's
  not possible to observe the regression.
* A file is exposed at `/proc/net_tcp`, which is a JSON document listing
  the current TCP sockets with a bit of metadata.
* There's an `ETHERNET_VERY_DEBUG` flag for dumping each packet's contents
  out to `kprintf`. It is, indeed, _very debug_.
This commit is contained in:
Conrad Pankoff 2019-08-06 23:40:38 +10:00 committed by Andreas Kling
parent c973a51a23
commit 73c998dbfc
12 changed files with 446 additions and 84 deletions

View file

@ -67,6 +67,7 @@ public:
}
in_addr_t to_in_addr_t() const { return m_data_as_u32; }
u32 to_u32() const { return m_data_as_u32; }
bool operator==(const IPv4Address& other) const { return m_data_as_u32 == other.m_data_as_u32; }
bool operator!=(const IPv4Address& other) const { return m_data_as_u32 != other.m_data_as_u32; }

View file

@ -14,6 +14,7 @@
#include <Kernel/FileSystem/VirtualFileSystem.h>
#include <Kernel/KParams.h>
#include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Net/TCPSocket.h>
#include <Kernel/PCI.h>
#include <Kernel/VM/MemoryManager.h>
#include <Kernel/kmalloc.h>
@ -46,6 +47,7 @@ enum ProcFileType {
FI_Root_uptime,
FI_Root_cmdline,
FI_Root_netadapters,
FI_Root_net_tcp,
FI_Root_self, // symlink
FI_Root_sys, // directory
__FI_Root_End,
@ -278,6 +280,23 @@ Optional<KBuffer> procfs$netadapters(InodeIdentifier)
return builder.to_byte_buffer();
}
// Generates the contents of /proc/net_tcp: a JSON array holding one object
// per socket visited by TCPSocket::for_each(). Each object carries the
// socket's local/peer address and port, its state name, and its current
// ack/sequence numbers.
Optional<KBuffer> procfs$net_tcp(InodeIdentifier)
{
    JsonArray socket_list;
    // The lambda is passed as a temporary, exactly as a callback argument.
    TCPSocket::for_each([&socket_list](auto& tcp_socket) {
        JsonObject entry;
        entry.set("local_address", tcp_socket->local_address().to_string());
        entry.set("local_port", tcp_socket->local_port());
        entry.set("peer_address", tcp_socket->peer_address().to_string());
        entry.set("peer_port", tcp_socket->peer_port());
        entry.set("state", TCPSocket::to_string(tcp_socket->state()));
        entry.set("ack_number", tcp_socket->ack_number());
        entry.set("sequence_number", tcp_socket->sequence_number());
        socket_list.append(entry);
    });
    auto serialized = socket_list.serialized();
    return serialized.to_byte_buffer();
}
Optional<KBuffer> procfs$pid_vmo(InodeIdentifier identifier)
{
auto handle = ProcessInspectionHandle::from_pid(to_pid(identifier));
@ -1077,6 +1096,7 @@ ProcFS::ProcFS()
m_entries[FI_Root_uptime] = { "uptime", FI_Root_uptime, procfs$uptime };
m_entries[FI_Root_cmdline] = { "cmdline", FI_Root_cmdline, procfs$cmdline };
m_entries[FI_Root_netadapters] = { "netadapters", FI_Root_netadapters, procfs$netadapters };
m_entries[FI_Root_net_tcp] = { "net_tcp", FI_Root_net_tcp, procfs$net_tcp };
m_entries[FI_Root_sys] = { "sys", FI_Root_sys };
m_entries[FI_PID_vm] = { "vm", FI_PID_vm, procfs$pid_vm };

View file

@ -89,6 +89,22 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
return protocol_bind();
}
// Puts this socket into listening mode.
//
// A local port is allocated first if the socket does not have one yet. The
// call then fails with EADDRINUSE when either the port allocation fails or
// the local address is still zero. On success the backlog is recorded and
// the protocol-specific protocol_listen() hook decides the final result.
//
// NOTE(review): both failure paths report EADDRINUSE, including the
// "local address is zero" case — confirm that this is the intended errno.
KResult IPv4Socket::listen(int backlog)
{
    const int port_result = allocate_local_port_if_needed();
    // Short-circuit keeps the original order: the address is only inspected
    // when port allocation did not fail.
    if (port_result < 0 || m_local_address.to_u32() == 0)
        return KResult(-EADDRINUSE);

    set_backlog(backlog);

    kprintf("IPv4Socket{%p} listening with backlog=%d\n", this, backlog);

    return protocol_listen();
}
KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
{
if (address_size != sizeof(sockaddr_in))
@ -157,6 +173,9 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt
if (!adapter)
return -EHOSTUNREACH;
if (m_local_address.to_u32() == 0)
m_local_address = adapter->ipv4_address();
int rc = allocate_local_port_if_needed();
if (rc < 0)
return rc;

View file

@ -2,10 +2,11 @@
#include <AK/HashMap.h>
#include <AK/SinglyLinkedList.h>
#include <Kernel/KBuffer.h>
#include <Kernel/DoubleBuffer.h>
#include <Kernel/KBuffer.h>
#include <Kernel/Lock.h>
#include <Kernel/Net/IPv4.h>
#include <Kernel/Net/IPv4SocketTuple.h>
#include <Kernel/Net/Socket.h>
class IPv4SocketHandle;
@ -23,6 +24,7 @@ public:
virtual KResult bind(const sockaddr*, socklen_t) override;
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
virtual KResult listen(int) override;
virtual bool get_local_address(sockaddr*, socklen_t*) override;
virtual bool get_peer_address(sockaddr*, socklen_t*) override;
virtual void attach(FileDescription&) override;
@ -34,7 +36,7 @@ public:
void did_receive(const IPv4Address& peer_address, u16 peer_port, KBuffer&&);
const IPv4Address& local_address() const;
// Returns a reference to this socket's stored local IPv4 address.
const IPv4Address& local_address() const { return m_local_address; }
u16 local_port() const { return m_local_port; }
void set_local_port(u16 port) { m_local_port = port; }
@ -42,6 +44,8 @@ public:
u16 peer_port() const { return m_peer_port; }
void set_peer_port(u16 port) { m_peer_port = port; }
// Builds an IPv4SocketTuple from the socket's current local address/port and peer address/port.
IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); }
protected:
IPv4Socket(int type, int protocol);
virtual const char* class_name() const override { return "IPv4Socket"; }
@ -49,12 +53,16 @@ protected:
int allocate_local_port_if_needed();
virtual KResult protocol_bind() { return KSuccess; }
// Protocol-specific listen hook; this default implementation does nothing and returns KSuccess.
virtual KResult protocol_listen() { return KSuccess; }
virtual int protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; }
virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
virtual int protocol_allocate_local_port() { return 0; }
virtual bool protocol_is_disconnected() const { return false; }
// Overwrites the stored local IPv4 address.
void set_local_address(IPv4Address address) { m_local_address = address; }
// Overwrites the stored peer IPv4 address.
void set_peer_address(IPv4Address address) { m_peer_address = address; }
private:
virtual bool is_ipv4() const override { return true; }

View file

@ -0,0 +1,63 @@
#pragma once
#include <AK/HashMap.h>
#include <AK/SinglyLinkedList.h>
#include <Kernel/DoubleBuffer.h>
#include <Kernel/KBuffer.h>
#include <Kernel/Lock.h>
#include <Kernel/Net/IPv4.h>
#include <Kernel/Net/Socket.h>
// Value type describing one IPv4 connection by its four endpoints:
// local address, local port, peer address and peer port.
//
// The fields are fixed at construction time and only exposed through
// getters. Two tuples compare equal when all four fields match.
class IPv4SocketTuple {
public:
    IPv4SocketTuple(IPv4Address local_address, u16 local_port, IPv4Address peer_address, u16 peer_port)
        : m_local_address(local_address)
        , m_local_port(local_port)
        , m_peer_address(peer_address)
        , m_peer_port(peer_port)
    {
    }

    IPv4Address local_address() const { return m_local_address; }
    u16 local_port() const { return m_local_port; }
    IPv4Address peer_address() const { return m_peer_address; }
    u16 peer_port() const { return m_peer_port; }

    // Field-wise equality. The argument is taken by const reference so that
    // comparing two tuples does not copy one of them.
    bool operator==(const IPv4SocketTuple& other) const
    {
        return other.local_address() == m_local_address
            && other.local_port() == m_local_port
            && other.peer_address() == m_peer_address
            && other.peer_port() == m_peer_port;
    }

    // Negation of operator==, provided for parity with IPv4Address.
    bool operator!=(const IPv4SocketTuple& other) const
    {
        return !(*this == other);
    }

    // Formats the tuple as "<local address>:<local port> -> <peer address>:<peer port>".
    String to_string() const
    {
        return String::format(
            "%s:%d -> %s:%d",
            m_local_address.to_string().characters(),
            m_local_port,
            m_peer_address.to_string().characters(),
            m_peer_port);
    }

private:
    IPv4Address m_local_address;
    u16 m_local_port { 0 };
    IPv4Address m_peer_address;
    u16 m_peer_port { 0 };
};
namespace AK {

// Traits specialization for IPv4SocketTuple, supplying a hash function and
// a debug dump routine on top of GenericTraits.
template<>
struct Traits<IPv4SocketTuple> : public GenericTraits<IPv4SocketTuple> {
    // Hashes the local (address, port) pair and the peer (address, port)
    // pair separately, then combines the two results.
    static unsigned hash(const IPv4SocketTuple& tuple)
    {
        const auto local_hash = pair_int_hash(tuple.local_address().to_u32(), tuple.local_port());
        const auto peer_hash = pair_int_hash(tuple.peer_address().to_u32(), tuple.peer_port());
        return pair_int_hash(local_hash, peer_hash);
    }

    // Prints the tuple's string form via kprintf.
    static void dump(const IPv4SocketTuple& tuple)
    {
        const auto description = tuple.to_string();
        kprintf("%s", description.characters());
    }
};

}

View file

@ -114,6 +114,16 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre
return KSuccess;
}
// Marks this local socket as listening. Only SOCK_STREAM sockets may
// listen; any other type is rejected with EOPNOTSUPP. The socket lock is
// held for the whole call.
KResult LocalSocket::listen(int backlog)
{
    LOCKER(lock());

    const bool is_stream_socket = (type() == SOCK_STREAM);
    if (!is_stream_socket)
        return KResult(-EOPNOTSUPP);

    set_backlog(backlog);
    kprintf("LocalSocket{%p} listening with backlog=%d\n", this, backlog);

    return KSuccess;
}
void LocalSocket::attach(FileDescription& description)
{
switch (description.socket_role()) {

View file

@ -13,6 +13,7 @@ public:
// ^Socket
virtual KResult bind(const sockaddr*, socklen_t) override;
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
virtual KResult listen(int) override;
virtual bool get_local_address(sockaddr*, socklen_t*) override;
virtual bool get_peer_address(sockaddr*, socklen_t*) override;
virtual void attach(FileDescription&) override;

View file

@ -14,6 +14,7 @@
#include <Kernel/Process.h>
//#define ETHERNET_DEBUG
//#define ETHERNET_VERY_DEBUG
//#define IPV4_DEBUG
//#define ICMP_DEBUG
//#define UDP_DEBUG
@ -84,6 +85,28 @@ void NetworkTask_main()
packet.size());
#endif
#ifdef ETHERNET_VERY_DEBUG
u8* data = packet.data();
for (size_t i = 0; i < packet.size(); i++) {
kprintf("%b", data[i]);
switch (i % 16) {
case 7:
kprintf(" ");
break;
case 15:
kprintf("\n");
break;
default:
kprintf(" ");
break;
}
}
kprintf("\n");
#endif
switch (eth.ether_type()) {
case EtherType::ARP:
handle_arp(eth, packet.size());
@ -279,7 +302,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size();
#ifdef TCP_DEBUG
kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s %s), window_size=%u, payload_size=%u\n",
kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s%s%s%s), window_size=%u, payload_size=%u\n",
ipv4_packet.source().to_string().characters(),
tcp_packet.source_port(),
ipv4_packet.destination().to_string().characters(),
@ -287,15 +310,19 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
tcp_packet.sequence_number(),
tcp_packet.ack_number(),
tcp_packet.flags(),
tcp_packet.has_syn() ? "SYN" : "",
tcp_packet.has_ack() ? "ACK" : "",
tcp_packet.has_syn() ? "SYN " : "",
tcp_packet.has_ack() ? "ACK " : "",
tcp_packet.has_fin() ? "FIN " : "",
tcp_packet.has_rst() ? "RST " : "",
tcp_packet.window_size(),
payload_size);
#endif
auto socket = TCPSocket::from_port(tcp_packet.destination_port());
IPv4SocketTuple tuple(ipv4_packet.destination(), tcp_packet.destination_port(), ipv4_packet.source(), tcp_packet.source_port());
auto socket = TCPSocket::from_tuple(tuple);
if (!socket) {
kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
kprintf("handle_tcp: No TCP socket for tuple %s\n", tuple.to_string().characters());
return;
}
@ -307,39 +334,168 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
return;
}
if (tcp_packet.has_syn() && tcp_packet.has_ack()) {
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::ACK);
socket->set_connected(true);
kprintf("handle_tcp: Connection established!\n");
socket->set_state(TCPSocket::State::Connected);
return;
}
#ifdef TCP_DEBUG
kprintf("handle_tcp: state=%s\n", TCPSocket::to_string(socket->state()));
#endif
if (tcp_packet.has_fin()) {
kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size);
switch (socket->state()) {
case TCPSocket::State::Closed:
kprintf("handle_tcp: unexpected flags in Closed state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: Closed -> Closed\n");
return;
case TCPSocket::State::TimeWait:
kprintf("handle_tcp: unexpected flags in TimeWait state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: TimeWait -> Closed\n");
return;
case TCPSocket::State::Listen:
switch (tcp_packet.flags()) {
case TCPFlags::SYN:
kprintf("handle_tcp: incoming connections not supported\n");
// socket->send_tcp_packet(TCPFlags::RST);
return;
default:
kprintf("handle_tcp: unexpected flags in Listen state\n");
// socket->send_tcp_packet(TCPFlags::RST);
return;
}
case TCPSocket::State::SynSent:
switch (tcp_packet.flags()) {
case TCPFlags::SYN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::SynReceived);
kprintf("handle_tcp: SynSent -> SynReceived\n");
return;
case TCPFlags::SYN | TCPFlags::ACK:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::Established);
socket->set_connected(true);
kprintf("handle_tcp: SynSent -> Established\n");
return;
default:
kprintf("handle_tcp: unexpected flags in SynSent state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: SynSent -> Closed\n");
return;
}
case TCPSocket::State::SynReceived:
switch (tcp_packet.flags()) {
case TCPFlags::ACK:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::Established);
socket->set_connected(true);
kprintf("handle_tcp: SynReceived -> Established\n");
return;
default:
kprintf("handle_tcp: unexpected flags in SynReceived state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: SynReceived -> Closed\n");
return;
}
case TCPSocket::State::CloseWait:
switch (tcp_packet.flags()) {
default:
kprintf("handle_tcp: unexpected flags in CloseWait state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: CloseWait -> Closed\n");
return;
}
case TCPSocket::State::LastAck:
switch (tcp_packet.flags()) {
case TCPFlags::ACK:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: LastAck -> Closed\n");
return;
default:
kprintf("handle_tcp: unexpected flags in LastAck state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: LastAck -> Closed\n");
return;
}
case TCPSocket::State::FinWait1:
switch (tcp_packet.flags()) {
case TCPFlags::ACK:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::FinWait2);
kprintf("handle_tcp: FinWait1 -> FinWait2\n");
return;
case TCPFlags::FIN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::Closing);
kprintf("handle_tcp: FinWait1 -> Closing\n");
return;
default:
kprintf("handle_tcp: unexpected flags in FinWait1 state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: FinWait1 -> Closed\n");
return;
}
case TCPSocket::State::FinWait2:
switch (tcp_packet.flags()) {
case TCPFlags::FIN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::TimeWait);
kprintf("handle_tcp: FinWait2 -> TimeWait\n");
return;
default:
kprintf("handle_tcp: unexpected flags in FinWait2 state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: FinWait2 -> Closed\n");
return;
}
case TCPSocket::State::Closing:
switch (tcp_packet.flags()) {
case TCPFlags::ACK:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::TimeWait);
kprintf("handle_tcp: Closing -> TimeWait\n");
return;
default:
kprintf("handle_tcp: unexpected flags in Closing state\n");
socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
kprintf("handle_tcp: Closing -> Closed\n");
return;
}
case TCPSocket::State::Established:
if (tcp_packet.has_fin()) {
if (payload_size != 0)
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::CloseWait);
socket->set_connected(false);
kprintf("handle_tcp: Established -> CloseWait\n");
return;
}
socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
#ifdef TCP_DEBUG
kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
tcp_packet.ack_number(),
tcp_packet.sequence_number(),
payload_size,
socket->ack_number(),
socket->sequence_number());
#endif
socket->send_tcp_packet(TCPFlags::ACK);
if (payload_size != 0)
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
socket->set_state(TCPSocket::State::Disconnecting);
socket->set_connected(false);
return;
}
socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
#ifdef TCP_DEBUG
kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
tcp_packet.ack_number(),
tcp_packet.sequence_number(),
payload_size,
socket->ack_number(),
socket->sequence_number());
#endif
socket->send_tcp_packet(TCPFlags::ACK);
if (payload_size != 0)
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
}

View file

@ -1,8 +1,8 @@
#pragma once
#include <AK/HashTable.h>
#include <AK/RefPtr.h>
#include <AK/RefCounted.h>
#include <AK/RefPtr.h>
#include <AK/Vector.h>
#include <Kernel/FileSystem/File.h>
#include <Kernel/KResult.h>
@ -35,10 +35,10 @@ public:
bool can_accept() const { return !m_pending.is_empty(); }
RefPtr<Socket> accept();
bool is_connected() const { return m_connected; }
KResult listen(int backlog);
virtual KResult bind(const sockaddr*, socklen_t) = 0;
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0;
virtual KResult listen(int) = 0;
virtual bool get_local_address(sockaddr*, socklen_t*) = 0;
virtual bool get_peer_address(sockaddr*, socklen_t*) = 0;
virtual bool is_local() const { return false; }
@ -73,6 +73,9 @@ protected:
void load_receive_deadline();
void load_send_deadline();
// Returns the stored backlog value (m_backlog).
int backlog() const { return m_backlog; }
// Stores the given backlog value in m_backlog.
void set_backlog(int backlog) { m_backlog = backlog; }
virtual const char* class_name() const override { return "Socket"; }
private:

View file

@ -39,6 +39,7 @@ public:
bool has_syn() const { return flags() & TCPFlags::SYN; }
bool has_ack() const { return flags() & TCPFlags::ACK; }
bool has_fin() const { return flags() & TCPFlags::FIN; }
// True when the RST bit is set in this packet's flags.
bool has_rst() const { return flags() & TCPFlags::RST; }
u8 data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; }
void set_data_offset(u16 data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; }

View file

@ -1,28 +1,35 @@
#include <Kernel/Devices/RandomDevice.h>
#include <Kernel/FileSystem/FileDescription.h>
#include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Net/Routing.h>
#include <Kernel/Net/TCP.h>
#include <Kernel/Net/TCPSocket.h>
#include <Kernel/FileSystem/FileDescription.h>
#include <Kernel/Process.h>
//#define TCP_SOCKET_DEBUG
Lockable<HashMap<u16, TCPSocket*>>& TCPSocket::sockets_by_port()
void TCPSocket::for_each(Function<void(TCPSocket*&)> callback)
{
static Lockable<HashMap<u16, TCPSocket*>>* s_map;
LOCKER(sockets_by_tuple().lock());
for (auto& it : sockets_by_tuple().resource())
callback(it.value);
}
Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple()
{
static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>* s_map;
if (!s_map)
s_map = new Lockable<HashMap<u16, TCPSocket*>>;
s_map = new Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>;
return *s_map;
}
TCPSocketHandle TCPSocket::from_port(u16 port)
TCPSocketHandle TCPSocket::from_tuple(const IPv4SocketTuple& tuple)
{
RefPtr<TCPSocket> socket;
{
LOCKER(sockets_by_port().lock());
auto it = sockets_by_port().resource().find(port);
if (it == sockets_by_port().resource().end())
LOCKER(sockets_by_tuple().lock());
auto it = sockets_by_tuple().resource().find(tuple);
if (it == sockets_by_tuple().resource().end())
return {};
socket = (*it).value;
ASSERT(socket);
@ -30,6 +37,11 @@ TCPSocketHandle TCPSocket::from_port(u16 port)
return { move(socket) };
}
// Convenience wrapper around from_tuple(): packs the four endpoint values
// into an IPv4SocketTuple and looks the socket up by that tuple.
TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port)
{
    IPv4SocketTuple endpoints(local_address, local_port, peer_address, peer_port);
    return from_tuple(endpoints);
}
TCPSocket::TCPSocket(int protocol)
: IPv4Socket(SOCK_STREAM, protocol)
{
@ -37,8 +49,8 @@ TCPSocket::TCPSocket(int protocol)
TCPSocket::~TCPSocket()
{
LOCKER(sockets_by_port().lock());
sockets_by_port().resource().remove(local_port());
LOCKER(sockets_by_tuple().lock());
sockets_by_tuple().resource().remove(tuple());
}
NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol)
@ -62,18 +74,13 @@ int TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size
int TCPSocket::protocol_send(const void* data, int data_length)
{
auto* adapter = adapter_for_route_to(peer_address());
if (!adapter)
return -EHOSTUNREACH;
send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length);
return data_length;
}
void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size)
{
// FIXME: Maybe the socket should be bound to an adapter instead of looking it up every time?
auto* adapter = adapter_for_route_to(peer_address());
ASSERT(adapter);
ASSERT(m_adapter);
auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size);
auto& tcp_packet = *(TCPPacket*)(buffer.pointer());
@ -95,19 +102,21 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size
}
memcpy(tcp_packet.payload(), payload, payload_size);
tcp_packet.set_checksum(compute_tcp_checksum(adapter->ipv4_address(), peer_address(), tcp_packet, payload_size));
tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size));
#ifdef TCP_SOCKET_DEBUG
kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n",
adapter->ipv4_address().to_string().characters(),
kprintf("sending tcp packet from %s:%u to %s:%u with (%s%s%s%s) seq_no=%u, ack_no=%u\n",
local_address().to_string().characters(),
local_port(),
peer_address().to_string().characters(),
peer_port(),
tcp_packet.has_syn() ? "SYN" : "",
tcp_packet.has_ack() ? "ACK" : "",
tcp_packet.has_fin() ? "FIN" : "",
tcp_packet.has_rst() ? "RST" : "",
tcp_packet.sequence_number(),
tcp_packet.ack_number());
#endif
adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size());
m_adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size());
}
NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size)
@ -152,11 +161,36 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, c
return ~(checksum & 0xffff);
}
// TCP-specific part of bind(). If no network adapter is attached to this
// socket yet, one is looked up by the socket's local address; the bind
// fails with EADDRNOTAVAIL when no adapter is found.
KResult TCPSocket::protocol_bind()
{
    // An adapter that is already assigned is kept as-is.
    if (m_adapter)
        return KSuccess;

    m_adapter = NetworkAdapter::from_ipv4_address(local_address());
    if (m_adapter)
        return KSuccess;

    return KResult(-EADDRNOTAVAIL);
}
// TCP-specific part of listen(). While holding the socket table lock,
// registers this socket under its current tuple and switches it to the
// Listen state. Fails with EADDRINUSE if that tuple is already present in
// the table.
KResult TCPSocket::protocol_listen()
{
    LOCKER(sockets_by_tuple().lock());

    auto& socket_table = sockets_by_tuple().resource();
    const auto listening_tuple = tuple();

    if (socket_table.contains(listening_tuple))
        return KResult(-EADDRINUSE);

    socket_table.set(listening_tuple, this);
    set_state(State::Listen);
    return KSuccess;
}
KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block)
{
auto* adapter = adapter_for_route_to(peer_address());
if (!adapter)
return KResult(-EHOSTUNREACH);
if (!m_adapter) {
m_adapter = adapter_for_route_to(peer_address());
if (!m_adapter)
return KResult(-EHOSTUNREACH);
set_local_address(m_adapter->ipv4_address());
}
allocate_local_port_if_needed();
@ -164,7 +198,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
m_ack_number = 0;
send_tcp_packet(TCPFlags::SYN);
m_state = State::Connecting;
m_state = State::SynSent;
if (should_block == ShouldBlock::Yes) {
if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal)
@ -183,12 +217,14 @@ int TCPSocket::protocol_allocate_local_port()
static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + RandomDevice::random_value() % ephemeral_port_range_size;
LOCKER(sockets_by_port().lock());
LOCKER(sockets_by_tuple().lock());
for (u16 port = first_scan_port;;) {
auto it = sockets_by_port().resource().find(port);
if (it == sockets_by_port().resource().end()) {
IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port());
auto it = sockets_by_tuple().resource().find(proposed_tuple);
if (it == sockets_by_tuple().resource().end()) {
set_local_port(port);
sockets_by_port().resource().set(port, this);
sockets_by_tuple().resource().set(proposed_tuple, this);
return port;
}
++port;
@ -202,14 +238,16 @@ int TCPSocket::protocol_allocate_local_port()
bool TCPSocket::protocol_is_disconnected() const
{
return m_state == State::Disconnecting || m_state == State::Disconnected;
}
KResult TCPSocket::protocol_bind()
{
LOCKER(sockets_by_port().lock());
if (sockets_by_port().resource().contains(local_port()))
return KResult(-EADDRINUSE);
sockets_by_port().resource().set(local_port(), this);
return KSuccess;
switch (m_state) {
case State::Closed:
case State::CloseWait:
case State::LastAck:
case State::FinWait1:
case State::FinWait2:
case State::Closing:
case State::TimeWait:
return true;
default:
return false;
}
}

View file

@ -1,19 +1,58 @@
#pragma once
#include <AK/Function.h>
#include <Kernel/Net/IPv4Socket.h>
class TCPSocket final : public IPv4Socket {
public:
static void for_each(Function<void(TCPSocket*&)>);
static NonnullRefPtr<TCPSocket> create(int protocol);
virtual ~TCPSocket() override;
enum class State {
Disconnected,
Connecting,
Connected,
Disconnecting,
Closed,
Listen,
SynSent,
SynReceived,
Established,
CloseWait,
LastAck,
FinWait1,
FinWait2,
Closing,
TimeWait,
};
// Maps a TCP connection state to its name as a string literal. Any state
// value without a dedicated case below yields "None".
static const char* to_string(State state)
{
    switch (state) {
    case State::Closed: return "Closed";
    case State::Listen: return "Listen";
    case State::SynSent: return "SynSent";
    case State::SynReceived: return "SynReceived";
    case State::Established: return "Established";
    case State::CloseWait: return "CloseWait";
    case State::LastAck: return "LastAck";
    case State::FinWait1: return "FinWait1";
    case State::FinWait2: return "FinWait2";
    case State::Closing: return "Closing";
    case State::TimeWait: return "TimeWait";
    default: break;
    }
    return "None";
}
State state() const { return m_state; }
void set_state(State state) { m_state = state; }
@ -24,8 +63,9 @@ public:
void send_tcp_packet(u16 flags, const void* = nullptr, int = 0);
static Lockable<HashMap<u16, TCPSocket*>>& sockets_by_port();
static TCPSocketHandle from_port(u16);
static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple();
static TCPSocketHandle from_tuple(const IPv4SocketTuple& tuple);
static TCPSocketHandle from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
private:
explicit TCPSocket(int protocol);
@ -39,10 +79,12 @@ private:
virtual int protocol_allocate_local_port() override;
virtual bool protocol_is_disconnected() const override;
virtual KResult protocol_bind() override;
virtual KResult protocol_listen() override;
NetworkAdapter* m_adapter { nullptr };
u32 m_sequence_number { 0 };
u32 m_ack_number { 0 };
State m_state { State::Disconnected };
State m_state { State::Closed };
};
class TCPSocketHandle : public SocketHandle {