diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h index 2f8b39a614c3..dbeb75c29857 100644 --- a/net/rxrpc/ar-internal.h +++ b/net/rxrpc/ar-internal.h @@ -1079,6 +1079,7 @@ void rxrpc_send_version_request(struct rxrpc_local *local, /* * local_object.c */ +void rxrpc_local_dont_fragment(const struct rxrpc_local *local, bool set); struct rxrpc_local *rxrpc_lookup_local(struct net *, const struct sockaddr_rxrpc *); struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *, enum rxrpc_local_trace); struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *, enum rxrpc_local_trace); diff --git a/net/rxrpc/local_object.c b/net/rxrpc/local_object.c index c553a30e9c83..34d307368135 100644 --- a/net/rxrpc/local_object.c +++ b/net/rxrpc/local_object.c @@ -36,6 +36,17 @@ static void rxrpc_encap_err_rcv(struct sock *sk, struct sk_buff *skb, int err, return ipv6_icmp_error(sk, skb, err, port, info, payload); } +/* + * Set or clear the Don't Fragment flag on a socket. + */ +void rxrpc_local_dont_fragment(const struct rxrpc_local *local, bool set) +{ + if (set) + ip_sock_set_mtu_discover(local->socket->sk, IP_PMTUDISC_DO); + else + ip_sock_set_mtu_discover(local->socket->sk, IP_PMTUDISC_DONT); +} + /* * Compare a local to an address. Return -ve, 0 or +ve to indicate less than, * same or greater than. @@ -203,7 +214,7 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) ip_sock_set_recverr(usk); /* we want to set the don't fragment bit */ - ip_sock_set_mtu_discover(usk, IP_PMTUDISC_DO); + rxrpc_local_dont_fragment(local, true); /* We want receive timestamps. */ sock_enable_timestamps(usk); diff --git a/net/rxrpc/output.c b/net/rxrpc/output.c index 5e53429c6922..a0906145e829 100644 --- a/net/rxrpc/output.c +++ b/net/rxrpc/output.c @@ -494,14 +494,12 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) switch (conn->local->srx.transport.family) { case AF_INET6: case AF_INET: - ip_sock_set_mtu_discover(conn->local->socket->sk, - IP_PMTUDISC_DONT); + rxrpc_local_dont_fragment(conn->local, false); rxrpc_inc_stat(call->rxnet, stat_tx_data_send_frag); ret = do_udp_sendmsg(conn->local->socket, &msg, len); conn->peer->last_tx_at = ktime_get_seconds(); - ip_sock_set_mtu_discover(conn->local->socket->sk, - IP_PMTUDISC_DO); + rxrpc_local_dont_fragment(conn->local, true); break; default: diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c index 1bf571a66e02..b52dedcebce0 100644 --- a/net/rxrpc/rxkad.c +++ b/net/rxrpc/rxkad.c @@ -724,7 +724,9 @@ static int rxkad_send_response(struct rxrpc_connection *conn, serial = atomic_inc_return(&conn->serial); whdr.serial = htonl(serial); + rxrpc_local_dont_fragment(conn->local, false); ret = kernel_sendmsg(conn->local->socket, &msg, iov, 3, len); + rxrpc_local_dont_fragment(conn->local, true); if (ret < 0) { trace_rxrpc_tx_fail(conn->debug_id, serial, ret, rxrpc_tx_point_rxkad_response);