mirror of
https://github.com/SerenityOS/serenity
synced 2024-10-15 12:23:15 +00:00
Kernel: Refactor TCP/IP stack
This has several significant changes to the networking stack. * Significant refactoring of the TCP state machine. Right now it's probably more fragile than it used to be, but handles quite a lot more of the handshake process. * `TCPSocket` holds a `NetworkAdapter*`, assigned during `connect()` or `bind()`, whichever comes first. * `listen()` is now virtual in `Socket` and intended to be implemented in its child classes * `listen()` no longer works without `bind()` - this is a bit of a regression, but listening sockets didn't work at all before, so it's not possible to observe the regression. * A file is exposed at `/proc/net_tcp`, which is a JSON document listing the current TCP sockets with a bit of metadata. * There's an `ETHERNET_VERY_DEBUG` flag for dumping packet's content out to `kprintf`. It is, indeed, _very debug_.
This commit is contained in:
parent
c973a51a23
commit
73c998dbfc
|
@ -67,6 +67,7 @@ public:
|
|||
}
|
||||
|
||||
in_addr_t to_in_addr_t() const { return m_data_as_u32; }
|
||||
u32 to_u32() const { return m_data_as_u32; }
|
||||
|
||||
bool operator==(const IPv4Address& other) const { return m_data_as_u32 == other.m_data_as_u32; }
|
||||
bool operator!=(const IPv4Address& other) const { return m_data_as_u32 != other.m_data_as_u32; }
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <Kernel/FileSystem/VirtualFileSystem.h>
|
||||
#include <Kernel/KParams.h>
|
||||
#include <Kernel/Net/NetworkAdapter.h>
|
||||
#include <Kernel/Net/TCPSocket.h>
|
||||
#include <Kernel/PCI.h>
|
||||
#include <Kernel/VM/MemoryManager.h>
|
||||
#include <Kernel/kmalloc.h>
|
||||
|
@ -46,6 +47,7 @@ enum ProcFileType {
|
|||
FI_Root_uptime,
|
||||
FI_Root_cmdline,
|
||||
FI_Root_netadapters,
|
||||
FI_Root_net_tcp,
|
||||
FI_Root_self, // symlink
|
||||
FI_Root_sys, // directory
|
||||
__FI_Root_End,
|
||||
|
@ -278,6 +280,23 @@ Optional<KBuffer> procfs$netadapters(InodeIdentifier)
|
|||
return builder.to_byte_buffer();
|
||||
}
|
||||
|
||||
Optional<KBuffer> procfs$net_tcp(InodeIdentifier)
|
||||
{
|
||||
JsonArray json;
|
||||
TCPSocket::for_each([&json](auto& socket) {
|
||||
JsonObject obj;
|
||||
obj.set("local_address", socket->local_address().to_string());
|
||||
obj.set("local_port", socket->local_port());
|
||||
obj.set("peer_address", socket->peer_address().to_string());
|
||||
obj.set("peer_port", socket->peer_port());
|
||||
obj.set("state", TCPSocket::to_string(socket->state()));
|
||||
obj.set("ack_number", socket->ack_number());
|
||||
obj.set("sequence_number", socket->sequence_number());
|
||||
json.append(obj);
|
||||
});
|
||||
return json.serialized().to_byte_buffer();
|
||||
}
|
||||
|
||||
Optional<KBuffer> procfs$pid_vmo(InodeIdentifier identifier)
|
||||
{
|
||||
auto handle = ProcessInspectionHandle::from_pid(to_pid(identifier));
|
||||
|
@ -1077,6 +1096,7 @@ ProcFS::ProcFS()
|
|||
m_entries[FI_Root_uptime] = { "uptime", FI_Root_uptime, procfs$uptime };
|
||||
m_entries[FI_Root_cmdline] = { "cmdline", FI_Root_cmdline, procfs$cmdline };
|
||||
m_entries[FI_Root_netadapters] = { "netadapters", FI_Root_netadapters, procfs$netadapters };
|
||||
m_entries[FI_Root_net_tcp] = { "net_tcp", FI_Root_net_tcp, procfs$net_tcp };
|
||||
m_entries[FI_Root_sys] = { "sys", FI_Root_sys };
|
||||
|
||||
m_entries[FI_PID_vm] = { "vm", FI_PID_vm, procfs$pid_vm };
|
||||
|
|
|
@ -89,6 +89,22 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
|
|||
return protocol_bind();
|
||||
}
|
||||
|
||||
KResult IPv4Socket::listen(int backlog)
|
||||
{
|
||||
int rc = allocate_local_port_if_needed();
|
||||
if (rc < 0)
|
||||
return KResult(-EADDRINUSE);
|
||||
|
||||
if (m_local_address.to_u32() == 0)
|
||||
return KResult(-EADDRINUSE);
|
||||
|
||||
set_backlog(backlog);
|
||||
|
||||
kprintf("IPv4Socket{%p} listening with backlog=%d\n", this, backlog);
|
||||
|
||||
return protocol_listen();
|
||||
}
|
||||
|
||||
KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
|
||||
{
|
||||
if (address_size != sizeof(sockaddr_in))
|
||||
|
@ -157,6 +173,9 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt
|
|||
if (!adapter)
|
||||
return -EHOSTUNREACH;
|
||||
|
||||
if (m_local_address.to_u32() == 0)
|
||||
m_local_address = adapter->ipv4_address();
|
||||
|
||||
int rc = allocate_local_port_if_needed();
|
||||
if (rc < 0)
|
||||
return rc;
|
||||
|
|
|
@ -2,10 +2,11 @@
|
|||
|
||||
#include <AK/HashMap.h>
|
||||
#include <AK/SinglyLinkedList.h>
|
||||
#include <Kernel/KBuffer.h>
|
||||
#include <Kernel/DoubleBuffer.h>
|
||||
#include <Kernel/KBuffer.h>
|
||||
#include <Kernel/Lock.h>
|
||||
#include <Kernel/Net/IPv4.h>
|
||||
#include <Kernel/Net/IPv4SocketTuple.h>
|
||||
#include <Kernel/Net/Socket.h>
|
||||
|
||||
class IPv4SocketHandle;
|
||||
|
@ -23,6 +24,7 @@ public:
|
|||
|
||||
virtual KResult bind(const sockaddr*, socklen_t) override;
|
||||
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
|
||||
virtual KResult listen(int) override;
|
||||
virtual bool get_local_address(sockaddr*, socklen_t*) override;
|
||||
virtual bool get_peer_address(sockaddr*, socklen_t*) override;
|
||||
virtual void attach(FileDescription&) override;
|
||||
|
@ -34,7 +36,7 @@ public:
|
|||
|
||||
void did_receive(const IPv4Address& peer_address, u16 peer_port, KBuffer&&);
|
||||
|
||||
const IPv4Address& local_address() const;
|
||||
const IPv4Address& local_address() const { return m_local_address; }
|
||||
u16 local_port() const { return m_local_port; }
|
||||
void set_local_port(u16 port) { m_local_port = port; }
|
||||
|
||||
|
@ -42,6 +44,8 @@ public:
|
|||
u16 peer_port() const { return m_peer_port; }
|
||||
void set_peer_port(u16 port) { m_peer_port = port; }
|
||||
|
||||
IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); }
|
||||
|
||||
protected:
|
||||
IPv4Socket(int type, int protocol);
|
||||
virtual const char* class_name() const override { return "IPv4Socket"; }
|
||||
|
@ -49,12 +53,16 @@ protected:
|
|||
int allocate_local_port_if_needed();
|
||||
|
||||
virtual KResult protocol_bind() { return KSuccess; }
|
||||
virtual KResult protocol_listen() { return KSuccess; }
|
||||
virtual int protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; }
|
||||
virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
|
||||
virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
|
||||
virtual int protocol_allocate_local_port() { return 0; }
|
||||
virtual bool protocol_is_disconnected() const { return false; }
|
||||
|
||||
void set_local_address(IPv4Address address) { m_local_address = address; }
|
||||
void set_peer_address(IPv4Address address) { m_peer_address = address; }
|
||||
|
||||
private:
|
||||
virtual bool is_ipv4() const override { return true; }
|
||||
|
||||
|
|
63
Kernel/Net/IPv4SocketTuple.h
Normal file
63
Kernel/Net/IPv4SocketTuple.h
Normal file
|
@ -0,0 +1,63 @@
|
|||
#pragma once
|
||||
|
||||
#include <AK/HashMap.h>
|
||||
#include <AK/SinglyLinkedList.h>
|
||||
#include <Kernel/DoubleBuffer.h>
|
||||
#include <Kernel/KBuffer.h>
|
||||
#include <Kernel/Lock.h>
|
||||
#include <Kernel/Net/IPv4.h>
|
||||
#include <Kernel/Net/Socket.h>
|
||||
|
||||
class IPv4SocketTuple {
|
||||
public:
|
||||
IPv4SocketTuple(IPv4Address local_address, u16 local_port, IPv4Address peer_address, u16 peer_port)
|
||||
: m_local_address(local_address)
|
||||
, m_local_port(local_port)
|
||||
, m_peer_address(peer_address)
|
||||
, m_peer_port(peer_port) {};
|
||||
|
||||
IPv4Address local_address() const { return m_local_address; };
|
||||
u16 local_port() const { return m_local_port; };
|
||||
IPv4Address peer_address() const { return m_peer_address; };
|
||||
u16 peer_port() const { return m_peer_port; };
|
||||
|
||||
bool operator==(const IPv4SocketTuple other) const
|
||||
{
|
||||
return other.local_address() == m_local_address && other.local_port() == m_local_port && other.peer_address() == m_peer_address && other.peer_port() == m_peer_port;
|
||||
};
|
||||
|
||||
String to_string() const
|
||||
{
|
||||
return String::format(
|
||||
"%s:%d -> %s:%d",
|
||||
m_local_address.to_string().characters(),
|
||||
m_local_port,
|
||||
m_peer_address.to_string().characters(),
|
||||
m_peer_port);
|
||||
}
|
||||
|
||||
private:
|
||||
IPv4Address m_local_address;
|
||||
u16 m_local_port { 0 };
|
||||
IPv4Address m_peer_address;
|
||||
u16 m_peer_port { 0 };
|
||||
};
|
||||
|
||||
namespace AK {
|
||||
|
||||
template<>
|
||||
struct Traits<IPv4SocketTuple> : public GenericTraits<IPv4SocketTuple> {
|
||||
static unsigned hash(const IPv4SocketTuple& tuple)
|
||||
{
|
||||
auto h1 = pair_int_hash(tuple.local_address().to_u32(), tuple.local_port());
|
||||
auto h2 = pair_int_hash(tuple.peer_address().to_u32(), tuple.peer_port());
|
||||
return pair_int_hash(h1, h2);
|
||||
}
|
||||
|
||||
static void dump(const IPv4SocketTuple& tuple)
|
||||
{
|
||||
kprintf("%s", tuple.to_string().characters());
|
||||
}
|
||||
};
|
||||
|
||||
}
|
|
@ -114,6 +114,16 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre
|
|||
return KSuccess;
|
||||
}
|
||||
|
||||
KResult LocalSocket::listen(int backlog)
|
||||
{
|
||||
LOCKER(lock());
|
||||
if (type() != SOCK_STREAM)
|
||||
return KResult(-EOPNOTSUPP);
|
||||
set_backlog(backlog);
|
||||
kprintf("LocalSocket{%p} listening with backlog=%d\n", this, backlog);
|
||||
return KSuccess;
|
||||
}
|
||||
|
||||
void LocalSocket::attach(FileDescription& description)
|
||||
{
|
||||
switch (description.socket_role()) {
|
||||
|
|
|
@ -13,6 +13,7 @@ public:
|
|||
// ^Socket
|
||||
virtual KResult bind(const sockaddr*, socklen_t) override;
|
||||
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
|
||||
virtual KResult listen(int) override;
|
||||
virtual bool get_local_address(sockaddr*, socklen_t*) override;
|
||||
virtual bool get_peer_address(sockaddr*, socklen_t*) override;
|
||||
virtual void attach(FileDescription&) override;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <Kernel/Process.h>
|
||||
|
||||
//#define ETHERNET_DEBUG
|
||||
//#define ETHERNET_VERY_DEBUG
|
||||
//#define IPV4_DEBUG
|
||||
//#define ICMP_DEBUG
|
||||
//#define UDP_DEBUG
|
||||
|
@ -84,6 +85,28 @@ void NetworkTask_main()
|
|||
packet.size());
|
||||
#endif
|
||||
|
||||
#ifdef ETHERNET_VERY_DEBUG
|
||||
u8* data = packet.data();
|
||||
|
||||
for (size_t i = 0; i < packet.size(); i++) {
|
||||
kprintf("%b", data[i]);
|
||||
|
||||
switch (i % 16) {
|
||||
case 7:
|
||||
kprintf(" ");
|
||||
break;
|
||||
case 15:
|
||||
kprintf("\n");
|
||||
break;
|
||||
default:
|
||||
kprintf(" ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
kprintf("\n");
|
||||
#endif
|
||||
|
||||
switch (eth.ether_type()) {
|
||||
case EtherType::ARP:
|
||||
handle_arp(eth, packet.size());
|
||||
|
@ -279,7 +302,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
|
|||
size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size();
|
||||
|
||||
#ifdef TCP_DEBUG
|
||||
kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s %s), window_size=%u, payload_size=%u\n",
|
||||
kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s%s%s%s), window_size=%u, payload_size=%u\n",
|
||||
ipv4_packet.source().to_string().characters(),
|
||||
tcp_packet.source_port(),
|
||||
ipv4_packet.destination().to_string().characters(),
|
||||
|
@ -287,15 +310,19 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
|
|||
tcp_packet.sequence_number(),
|
||||
tcp_packet.ack_number(),
|
||||
tcp_packet.flags(),
|
||||
tcp_packet.has_syn() ? "SYN" : "",
|
||||
tcp_packet.has_ack() ? "ACK" : "",
|
||||
tcp_packet.has_syn() ? "SYN " : "",
|
||||
tcp_packet.has_ack() ? "ACK " : "",
|
||||
tcp_packet.has_fin() ? "FIN " : "",
|
||||
tcp_packet.has_rst() ? "RST " : "",
|
||||
tcp_packet.window_size(),
|
||||
payload_size);
|
||||
#endif
|
||||
|
||||
auto socket = TCPSocket::from_port(tcp_packet.destination_port());
|
||||
IPv4SocketTuple tuple(ipv4_packet.destination(), tcp_packet.destination_port(), ipv4_packet.source(), tcp_packet.source_port());
|
||||
|
||||
auto socket = TCPSocket::from_tuple(tuple);
|
||||
if (!socket) {
|
||||
kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port());
|
||||
kprintf("handle_tcp: No TCP socket for tuple %s\n", tuple.to_string().characters());
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -307,39 +334,168 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
|
|||
return;
|
||||
}
|
||||
|
||||
if (tcp_packet.has_syn() && tcp_packet.has_ack()) {
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->send_tcp_packet(TCPFlags::ACK);
|
||||
socket->set_connected(true);
|
||||
kprintf("handle_tcp: Connection established!\n");
|
||||
socket->set_state(TCPSocket::State::Connected);
|
||||
return;
|
||||
}
|
||||
#ifdef TCP_DEBUG
|
||||
kprintf("handle_tcp: state=%s\n", TCPSocket::to_string(socket->state()));
|
||||
#endif
|
||||
|
||||
if (tcp_packet.has_fin()) {
|
||||
kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size);
|
||||
switch (socket->state()) {
|
||||
case TCPSocket::State::Closed:
|
||||
kprintf("handle_tcp: unexpected flags in Closed state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: Closed -> Closed\n");
|
||||
return;
|
||||
case TCPSocket::State::TimeWait:
|
||||
kprintf("handle_tcp: unexpected flags in TimeWait state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: TimeWait -> Closed\n");
|
||||
return;
|
||||
case TCPSocket::State::Listen:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::SYN:
|
||||
kprintf("handle_tcp: incoming connections not supported\n");
|
||||
// socket->send_tcp_packet(TCPFlags::RST);
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in Listen state\n");
|
||||
// socket->send_tcp_packet(TCPFlags::RST);
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::SynSent:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::SYN:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->send_tcp_packet(TCPFlags::ACK);
|
||||
socket->set_state(TCPSocket::State::SynReceived);
|
||||
kprintf("handle_tcp: SynSent -> SynReceived\n");
|
||||
return;
|
||||
case TCPFlags::SYN | TCPFlags::ACK:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->send_tcp_packet(TCPFlags::ACK);
|
||||
socket->set_state(TCPSocket::State::Established);
|
||||
socket->set_connected(true);
|
||||
kprintf("handle_tcp: SynSent -> Established\n");
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in SynSent state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: SynSent -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::SynReceived:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::ACK:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->set_state(TCPSocket::State::Established);
|
||||
socket->set_connected(true);
|
||||
kprintf("handle_tcp: SynReceived -> Established\n");
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in SynReceived state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: SynReceived -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::CloseWait:
|
||||
switch (tcp_packet.flags()) {
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in CloseWait state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: CloseWait -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::LastAck:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::ACK:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: LastAck -> Closed\n");
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in LastAck state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: LastAck -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::FinWait1:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::ACK:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->set_state(TCPSocket::State::FinWait2);
|
||||
kprintf("handle_tcp: FinWait1 -> FinWait2\n");
|
||||
return;
|
||||
case TCPFlags::FIN:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->set_state(TCPSocket::State::Closing);
|
||||
kprintf("handle_tcp: FinWait1 -> Closing\n");
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in FinWait1 state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: FinWait1 -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::FinWait2:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::FIN:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->set_state(TCPSocket::State::TimeWait);
|
||||
kprintf("handle_tcp: FinWait2 -> TimeWait\n");
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in FinWait2 state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: FinWait2 -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::Closing:
|
||||
switch (tcp_packet.flags()) {
|
||||
case TCPFlags::ACK:
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->set_state(TCPSocket::State::TimeWait);
|
||||
kprintf("handle_tcp: Closing -> TimeWait\n");
|
||||
return;
|
||||
default:
|
||||
kprintf("handle_tcp: unexpected flags in Closing state\n");
|
||||
socket->send_tcp_packet(TCPFlags::RST);
|
||||
socket->set_state(TCPSocket::State::Closed);
|
||||
kprintf("handle_tcp: Closing -> Closed\n");
|
||||
return;
|
||||
}
|
||||
case TCPSocket::State::Established:
|
||||
if (tcp_packet.has_fin()) {
|
||||
if (payload_size != 0)
|
||||
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
|
||||
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->send_tcp_packet(TCPFlags::ACK);
|
||||
socket->set_state(TCPSocket::State::CloseWait);
|
||||
socket->set_connected(false);
|
||||
kprintf("handle_tcp: Established -> CloseWait\n");
|
||||
return;
|
||||
}
|
||||
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
|
||||
|
||||
#ifdef TCP_DEBUG
|
||||
kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
|
||||
tcp_packet.ack_number(),
|
||||
tcp_packet.sequence_number(),
|
||||
payload_size,
|
||||
socket->ack_number(),
|
||||
socket->sequence_number());
|
||||
#endif
|
||||
|
||||
socket->send_tcp_packet(TCPFlags::ACK);
|
||||
|
||||
if (payload_size != 0)
|
||||
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
|
||||
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
|
||||
socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
|
||||
socket->set_state(TCPSocket::State::Disconnecting);
|
||||
socket->set_connected(false);
|
||||
return;
|
||||
}
|
||||
|
||||
socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
|
||||
#ifdef TCP_DEBUG
|
||||
kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
|
||||
tcp_packet.ack_number(),
|
||||
tcp_packet.sequence_number(),
|
||||
payload_size,
|
||||
socket->ack_number(),
|
||||
socket->sequence_number());
|
||||
#endif
|
||||
socket->send_tcp_packet(TCPFlags::ACK);
|
||||
|
||||
if (payload_size != 0)
|
||||
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <AK/HashTable.h>
|
||||
#include <AK/RefPtr.h>
|
||||
#include <AK/RefCounted.h>
|
||||
#include <AK/RefPtr.h>
|
||||
#include <AK/Vector.h>
|
||||
#include <Kernel/FileSystem/File.h>
|
||||
#include <Kernel/KResult.h>
|
||||
|
@ -35,10 +35,10 @@ public:
|
|||
bool can_accept() const { return !m_pending.is_empty(); }
|
||||
RefPtr<Socket> accept();
|
||||
bool is_connected() const { return m_connected; }
|
||||
KResult listen(int backlog);
|
||||
|
||||
virtual KResult bind(const sockaddr*, socklen_t) = 0;
|
||||
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0;
|
||||
virtual KResult listen(int) = 0;
|
||||
virtual bool get_local_address(sockaddr*, socklen_t*) = 0;
|
||||
virtual bool get_peer_address(sockaddr*, socklen_t*) = 0;
|
||||
virtual bool is_local() const { return false; }
|
||||
|
@ -73,6 +73,9 @@ protected:
|
|||
void load_receive_deadline();
|
||||
void load_send_deadline();
|
||||
|
||||
int backlog() const { return m_backlog; }
|
||||
void set_backlog(int backlog) { m_backlog = backlog; }
|
||||
|
||||
virtual const char* class_name() const override { return "Socket"; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,6 +39,7 @@ public:
|
|||
bool has_syn() const { return flags() & TCPFlags::SYN; }
|
||||
bool has_ack() const { return flags() & TCPFlags::ACK; }
|
||||
bool has_fin() const { return flags() & TCPFlags::FIN; }
|
||||
bool has_rst() const { return flags() & TCPFlags::RST; }
|
||||
|
||||
u8 data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; }
|
||||
void set_data_offset(u16 data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; }
|
||||
|
|
|
@ -1,28 +1,35 @@
|
|||
#include <Kernel/Devices/RandomDevice.h>
|
||||
#include <Kernel/FileSystem/FileDescription.h>
|
||||
#include <Kernel/Net/NetworkAdapter.h>
|
||||
#include <Kernel/Net/Routing.h>
|
||||
#include <Kernel/Net/TCP.h>
|
||||
#include <Kernel/Net/TCPSocket.h>
|
||||
#include <Kernel/FileSystem/FileDescription.h>
|
||||
#include <Kernel/Process.h>
|
||||
|
||||
//#define TCP_SOCKET_DEBUG
|
||||
|
||||
Lockable<HashMap<u16, TCPSocket*>>& TCPSocket::sockets_by_port()
|
||||
void TCPSocket::for_each(Function<void(TCPSocket*&)> callback)
|
||||
{
|
||||
static Lockable<HashMap<u16, TCPSocket*>>* s_map;
|
||||
LOCKER(sockets_by_tuple().lock());
|
||||
for (auto& it : sockets_by_tuple().resource())
|
||||
callback(it.value);
|
||||
}
|
||||
|
||||
Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple()
|
||||
{
|
||||
static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>* s_map;
|
||||
if (!s_map)
|
||||
s_map = new Lockable<HashMap<u16, TCPSocket*>>;
|
||||
s_map = new Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>;
|
||||
return *s_map;
|
||||
}
|
||||
|
||||
TCPSocketHandle TCPSocket::from_port(u16 port)
|
||||
TCPSocketHandle TCPSocket::from_tuple(const IPv4SocketTuple& tuple)
|
||||
{
|
||||
RefPtr<TCPSocket> socket;
|
||||
{
|
||||
LOCKER(sockets_by_port().lock());
|
||||
auto it = sockets_by_port().resource().find(port);
|
||||
if (it == sockets_by_port().resource().end())
|
||||
LOCKER(sockets_by_tuple().lock());
|
||||
auto it = sockets_by_tuple().resource().find(tuple);
|
||||
if (it == sockets_by_tuple().resource().end())
|
||||
return {};
|
||||
socket = (*it).value;
|
||||
ASSERT(socket);
|
||||
|
@ -30,6 +37,11 @@ TCPSocketHandle TCPSocket::from_port(u16 port)
|
|||
return { move(socket) };
|
||||
}
|
||||
|
||||
TCPSocketHandle TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port)
|
||||
{
|
||||
return from_tuple(IPv4SocketTuple(local_address, local_port, peer_address, peer_port));
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(int protocol)
|
||||
: IPv4Socket(SOCK_STREAM, protocol)
|
||||
{
|
||||
|
@ -37,8 +49,8 @@ TCPSocket::TCPSocket(int protocol)
|
|||
|
||||
TCPSocket::~TCPSocket()
|
||||
{
|
||||
LOCKER(sockets_by_port().lock());
|
||||
sockets_by_port().resource().remove(local_port());
|
||||
LOCKER(sockets_by_tuple().lock());
|
||||
sockets_by_tuple().resource().remove(tuple());
|
||||
}
|
||||
|
||||
NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol)
|
||||
|
@ -62,18 +74,13 @@ int TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size
|
|||
|
||||
int TCPSocket::protocol_send(const void* data, int data_length)
|
||||
{
|
||||
auto* adapter = adapter_for_route_to(peer_address());
|
||||
if (!adapter)
|
||||
return -EHOSTUNREACH;
|
||||
send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length);
|
||||
return data_length;
|
||||
}
|
||||
|
||||
void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size)
|
||||
{
|
||||
// FIXME: Maybe the socket should be bound to an adapter instead of looking it up every time?
|
||||
auto* adapter = adapter_for_route_to(peer_address());
|
||||
ASSERT(adapter);
|
||||
ASSERT(m_adapter);
|
||||
|
||||
auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size);
|
||||
auto& tcp_packet = *(TCPPacket*)(buffer.pointer());
|
||||
|
@ -95,19 +102,21 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size
|
|||
}
|
||||
|
||||
memcpy(tcp_packet.payload(), payload, payload_size);
|
||||
tcp_packet.set_checksum(compute_tcp_checksum(adapter->ipv4_address(), peer_address(), tcp_packet, payload_size));
|
||||
tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size));
|
||||
#ifdef TCP_SOCKET_DEBUG
|
||||
kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n",
|
||||
adapter->ipv4_address().to_string().characters(),
|
||||
kprintf("sending tcp packet from %s:%u to %s:%u with (%s%s%s%s) seq_no=%u, ack_no=%u\n",
|
||||
local_address().to_string().characters(),
|
||||
local_port(),
|
||||
peer_address().to_string().characters(),
|
||||
peer_port(),
|
||||
tcp_packet.has_syn() ? "SYN" : "",
|
||||
tcp_packet.has_ack() ? "ACK" : "",
|
||||
tcp_packet.has_fin() ? "FIN" : "",
|
||||
tcp_packet.has_rst() ? "RST" : "",
|
||||
tcp_packet.sequence_number(),
|
||||
tcp_packet.ack_number());
|
||||
#endif
|
||||
adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size());
|
||||
m_adapter->send_ipv4(MACAddress(), peer_address(), IPv4Protocol::TCP, buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size)
|
||||
|
@ -152,11 +161,36 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, c
|
|||
return ~(checksum & 0xffff);
|
||||
}
|
||||
|
||||
KResult TCPSocket::protocol_bind()
|
||||
{
|
||||
if (!m_adapter) {
|
||||
m_adapter = NetworkAdapter::from_ipv4_address(local_address());
|
||||
if (!m_adapter)
|
||||
return KResult(-EADDRNOTAVAIL);
|
||||
}
|
||||
|
||||
return KSuccess;
|
||||
}
|
||||
|
||||
KResult TCPSocket::protocol_listen()
|
||||
{
|
||||
LOCKER(sockets_by_tuple().lock());
|
||||
if (sockets_by_tuple().resource().contains(tuple()))
|
||||
return KResult(-EADDRINUSE);
|
||||
sockets_by_tuple().resource().set(tuple(), this);
|
||||
set_state(State::Listen);
|
||||
return KSuccess;
|
||||
}
|
||||
|
||||
KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block)
|
||||
{
|
||||
auto* adapter = adapter_for_route_to(peer_address());
|
||||
if (!adapter)
|
||||
return KResult(-EHOSTUNREACH);
|
||||
if (!m_adapter) {
|
||||
m_adapter = adapter_for_route_to(peer_address());
|
||||
if (!m_adapter)
|
||||
return KResult(-EHOSTUNREACH);
|
||||
|
||||
set_local_address(m_adapter->ipv4_address());
|
||||
}
|
||||
|
||||
allocate_local_port_if_needed();
|
||||
|
||||
|
@ -164,7 +198,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
|
|||
m_ack_number = 0;
|
||||
|
||||
send_tcp_packet(TCPFlags::SYN);
|
||||
m_state = State::Connecting;
|
||||
m_state = State::SynSent;
|
||||
|
||||
if (should_block == ShouldBlock::Yes) {
|
||||
if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal)
|
||||
|
@ -183,12 +217,14 @@ int TCPSocket::protocol_allocate_local_port()
|
|||
static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
|
||||
u16 first_scan_port = first_ephemeral_port + RandomDevice::random_value() % ephemeral_port_range_size;
|
||||
|
||||
LOCKER(sockets_by_port().lock());
|
||||
LOCKER(sockets_by_tuple().lock());
|
||||
for (u16 port = first_scan_port;;) {
|
||||
auto it = sockets_by_port().resource().find(port);
|
||||
if (it == sockets_by_port().resource().end()) {
|
||||
IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port());
|
||||
|
||||
auto it = sockets_by_tuple().resource().find(proposed_tuple);
|
||||
if (it == sockets_by_tuple().resource().end()) {
|
||||
set_local_port(port);
|
||||
sockets_by_port().resource().set(port, this);
|
||||
sockets_by_tuple().resource().set(proposed_tuple, this);
|
||||
return port;
|
||||
}
|
||||
++port;
|
||||
|
@ -202,14 +238,16 @@ int TCPSocket::protocol_allocate_local_port()
|
|||
|
||||
bool TCPSocket::protocol_is_disconnected() const
|
||||
{
|
||||
return m_state == State::Disconnecting || m_state == State::Disconnected;
|
||||
}
|
||||
|
||||
KResult TCPSocket::protocol_bind()
|
||||
{
|
||||
LOCKER(sockets_by_port().lock());
|
||||
if (sockets_by_port().resource().contains(local_port()))
|
||||
return KResult(-EADDRINUSE);
|
||||
sockets_by_port().resource().set(local_port(), this);
|
||||
return KSuccess;
|
||||
switch (m_state) {
|
||||
case State::Closed:
|
||||
case State::CloseWait:
|
||||
case State::LastAck:
|
||||
case State::FinWait1:
|
||||
case State::FinWait2:
|
||||
case State::Closing:
|
||||
case State::TimeWait:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,19 +1,58 @@
|
|||
#pragma once
|
||||
|
||||
#include <AK/Function.h>
|
||||
#include <Kernel/Net/IPv4Socket.h>
|
||||
|
||||
class TCPSocket final : public IPv4Socket {
|
||||
public:
|
||||
static void for_each(Function<void(TCPSocket*&)>);
|
||||
static NonnullRefPtr<TCPSocket> create(int protocol);
|
||||
virtual ~TCPSocket() override;
|
||||
|
||||
enum class State {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
Disconnecting,
|
||||
Closed,
|
||||
Listen,
|
||||
SynSent,
|
||||
SynReceived,
|
||||
Established,
|
||||
CloseWait,
|
||||
LastAck,
|
||||
FinWait1,
|
||||
FinWait2,
|
||||
Closing,
|
||||
TimeWait,
|
||||
};
|
||||
|
||||
static const char* to_string(State state)
|
||||
{
|
||||
switch (state) {
|
||||
case State::Closed:
|
||||
return "Closed";
|
||||
case State::Listen:
|
||||
return "Listen";
|
||||
case State::SynSent:
|
||||
return "SynSent";
|
||||
case State::SynReceived:
|
||||
return "SynReceived";
|
||||
case State::Established:
|
||||
return "Established";
|
||||
case State::CloseWait:
|
||||
return "CloseWait";
|
||||
case State::LastAck:
|
||||
return "LastAck";
|
||||
case State::FinWait1:
|
||||
return "FinWait1";
|
||||
case State::FinWait2:
|
||||
return "FinWait2";
|
||||
case State::Closing:
|
||||
return "Closing";
|
||||
case State::TimeWait:
|
||||
return "TimeWait";
|
||||
default:
|
||||
return "None";
|
||||
}
|
||||
}
|
||||
|
||||
State state() const { return m_state; }
|
||||
void set_state(State state) { m_state = state; }
|
||||
|
||||
|
@ -24,8 +63,9 @@ public:
|
|||
|
||||
void send_tcp_packet(u16 flags, const void* = nullptr, int = 0);
|
||||
|
||||
static Lockable<HashMap<u16, TCPSocket*>>& sockets_by_port();
|
||||
static TCPSocketHandle from_port(u16);
|
||||
static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple();
|
||||
static TCPSocketHandle from_tuple(const IPv4SocketTuple& tuple);
|
||||
static TCPSocketHandle from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
|
||||
|
||||
private:
|
||||
explicit TCPSocket(int protocol);
|
||||
|
@ -39,10 +79,12 @@ private:
|
|||
virtual int protocol_allocate_local_port() override;
|
||||
virtual bool protocol_is_disconnected() const override;
|
||||
virtual KResult protocol_bind() override;
|
||||
virtual KResult protocol_listen() override;
|
||||
|
||||
NetworkAdapter* m_adapter { nullptr };
|
||||
u32 m_sequence_number { 0 };
|
||||
u32 m_ack_number { 0 };
|
||||
State m_state { State::Disconnected };
|
||||
State m_state { State::Closed };
|
||||
};
|
||||
|
||||
class TCPSocketHandle : public SocketHandle {
|
||||
|
|
Loading…
Reference in a new issue