TCP: Start working on auto-closing connections when we get FIN.

This commit is contained in:
Andreas Kling 2019-03-14 15:23:32 +01:00
parent 4629272135
commit 25e521f510
6 changed files with 35 additions and 6 deletions

View file

@ -92,22 +92,24 @@ void IPv4Socket::detach_fd(SocketRole)
bool IPv4Socket::can_read(SocketRole) const
{
if (protocol_is_disconnected())
return true;
return m_can_read;
}
ssize_t IPv4Socket::read(SocketRole, byte*, ssize_t)
ssize_t IPv4Socket::read(SocketRole, byte* buffer, ssize_t size)
{
ASSERT_NOT_REACHED();
return recvfrom(buffer, size, 0, nullptr, 0);
}
ssize_t IPv4Socket::write(SocketRole, const byte*, ssize_t)
ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size)
{
ASSERT_NOT_REACHED();
return sendto(data, size, 0, nullptr, 0);
}
bool IPv4Socket::can_write(SocketRole) const
{
ASSERT_NOT_REACHED();
return true;
}
void IPv4Socket::allocate_source_port_if_needed()
@ -168,9 +170,17 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock
if (!m_receive_queue.is_empty()) {
packet_buffer = m_receive_queue.take_first();
m_can_read = !m_receive_queue.is_empty();
#ifdef IPV4_SOCKET_DEBUG
kprintf("IPv4Socket(%p): recvfrom without blocking %d bytes, packets in queue: %d\n", this, packet_buffer.size(), m_receive_queue.size_slow());
#endif
}
}
if (packet_buffer.is_null()) {
if (protocol_is_disconnected()) {
kprintf("IPv4Socket{%p} is protocol-disconnected, returning 0 in recvfrom!\n", this);
return 0;
}
current->set_blocked_socket(this);
load_receive_deadline();
block(Process::BlockedReceive);
@ -185,6 +195,9 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock
ASSERT(!m_receive_queue.is_empty());
packet_buffer = m_receive_queue.take_first();
m_can_read = !m_receive_queue.is_empty();
#ifdef IPV4_SOCKET_DEBUG
kprintf("IPv4Socket(%p): recvfrom with blocking %d bytes, packets in queue: %d\n", this, packet_buffer.size(), m_receive_queue.size_slow());
#endif
}
ASSERT(!packet_buffer.is_null());
auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer());

View file

@ -51,6 +51,7 @@ protected:
virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
virtual KResult protocol_connect() { return KSuccess; }
virtual void protocol_allocate_source_port() { }
virtual bool protocol_is_disconnected() const { return false; }
private:
virtual bool is_ipv4() const override { return true; }

View file

@ -287,7 +287,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
ASSERT(socket->source_port() == tcp_packet.destination_port());
if (tcp_packet.ack_number() != socket->sequence_number()) {
kprintf("handle_tcp: ack/seq mismatch: got %u, wanted %u\n",tcp_packet.ack_number(), socket->sequence_number());
kprintf("handle_tcp: ack/seq mismatch: got %u, wanted %u\n", tcp_packet.ack_number(), socket->sequence_number());
return;
}
@ -300,6 +300,14 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
return;
}
if (tcp_packet.has_fin()) {
kprintf("handle_tcp: Got FIN, payload_size=%u\n", payload_size);
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
socket->set_state(TCPSocket::State::Disconnecting);
return;
}
socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
tcp_packet.ack_number(),

View file

@ -37,6 +37,7 @@ public:
bool has_syn() const { return flags() & TCPFlags::SYN; }
bool has_ack() const { return flags() & TCPFlags::ACK; }
bool has_fin() const { return flags() & TCPFlags::FIN; }
byte data_offset() const { return (m_flags_and_data_offset & 0xf000) >> 12; }
void set_data_offset(word data_offset) { m_flags_and_data_offset = (m_flags_and_data_offset & ~0xf000) | data_offset << 12; }

View file

@ -186,3 +186,8 @@ void TCPSocket::protocol_allocate_source_port()
}
}
}
bool TCPSocket::protocol_is_disconnected() const
{
return m_state == State::Disconnecting || m_state == State::Disconnected;
}

View file

@ -36,6 +36,7 @@ private:
virtual int protocol_send(const void*, int) override;
virtual KResult protocol_connect() override;
virtual void protocol_allocate_source_port() override;
virtual bool protocol_is_disconnected() const override;
dword m_sequence_number { 0 };
dword m_ack_number { 0 };