ktls: deep copy tls_enable struct for in-kernel tcp consumers

Doing a deep copy of the keys early allows users of the
tls_enable structure to assume kernel memory.
This enables the socket options to be set by kernel threads.

Reviewed By:	#transport, tuexen, jhb, rrs
Sponsored by:	NetApp, Inc.
X-NetApp-PR:	#79
Differential Revision:	https://reviews.freebsd.org/D44250
This commit is contained in:
Richard Scheffenegger 2024-03-13 12:35:51 +01:00
parent bf8a3a816d
commit 85df11a1de
3 changed files with 97 additions and 60 deletions

View file

@ -297,10 +297,86 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_toe, OID_AUTO, chacha20, CTLFLAG_RD,
static MALLOC_DEFINE(M_KTLS, "ktls", "Kernel TLS");
static void ktls_reclaim_thread(void *ctx);
static void ktls_reset_receive_tag(void *context, int pending);
static void ktls_reset_send_tag(void *context, int pending);
static void ktls_work_thread(void *ctx);
static void ktls_reclaim_thread(void *ctx);
int
ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
{
struct tls_enable_v0 tls_v0;
int error;
uint8_t *cipher_key = NULL, *iv = NULL, *auth_key = NULL;
if (sopt->sopt_valsize == sizeof(tls_v0)) {
error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0), sizeof(tls_v0));
if (error != 0)
goto done;
memset(tls, 0, sizeof(*tls));
tls->cipher_key = tls_v0.cipher_key;
tls->iv = tls_v0.iv;
tls->auth_key = tls_v0.auth_key;
tls->cipher_algorithm = tls_v0.cipher_algorithm;
tls->cipher_key_len = tls_v0.cipher_key_len;
tls->iv_len = tls_v0.iv_len;
tls->auth_algorithm = tls_v0.auth_algorithm;
tls->auth_key_len = tls_v0.auth_key_len;
tls->flags = tls_v0.flags;
tls->tls_vmajor = tls_v0.tls_vmajor;
tls->tls_vminor = tls_v0.tls_vminor;
} else
error = sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls));
if (error != 0)
goto done;
/*
* Now do a deep copy of the variable-length arrays in the struct, so that
* subsequent consumers of it can reliably assume kernel memory. This
* requires doing our own allocations, which we will free in the
* error paths so that our caller need only worry about outstanding
* allocations existing on successful return.
*/
cipher_key = malloc(tls->cipher_key_len, M_KTLS, M_WAITOK);
iv = malloc(tls->iv_len, M_KTLS, M_WAITOK);
auth_key = malloc(tls->auth_key_len, M_KTLS, M_WAITOK);
if (sopt->sopt_td != NULL) {
error = copyin(tls->cipher_key, cipher_key, tls->cipher_key_len);
if (error != 0)
goto done;
error = copyin(tls->iv, iv, tls->iv_len);
if (error != 0)
goto done;
error = copyin(tls->auth_key, auth_key, tls->auth_key_len);
if (error != 0)
goto done;
} else {
bcopy(tls->cipher_key, cipher_key, tls->cipher_key_len);
bcopy(tls->iv, iv, tls->iv_len);
bcopy(tls->auth_key, auth_key, tls->auth_key_len);
}
tls->cipher_key = cipher_key;
tls->iv = iv;
tls->auth_key = auth_key;
done:
if (error != 0) {
zfree(cipher_key, M_KTLS);
zfree(iv, M_KTLS);
zfree(auth_key, M_KTLS);
}
return (error);
}
void
ktls_cleanup_tls_enable(struct tls_enable *tls)
{
zfree(__DECONST(void *, tls->cipher_key), M_KTLS);
zfree(__DECONST(void *, tls->iv), M_KTLS);
zfree(__DECONST(void *, tls->auth_key), M_KTLS);
}
static u_int
ktls_get_cpu(struct socket *so)
@ -702,18 +778,12 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
tls->params.auth_key_len = en->auth_key_len;
tls->params.auth_key = malloc(en->auth_key_len, M_KTLS,
M_WAITOK);
error = copyin(en->auth_key, tls->params.auth_key,
en->auth_key_len);
if (error)
goto out;
bcopy(en->auth_key, tls->params.auth_key, en->auth_key_len);
}
tls->params.cipher_key_len = en->cipher_key_len;
tls->params.cipher_key = malloc(en->cipher_key_len, M_KTLS, M_WAITOK);
error = copyin(en->cipher_key, tls->params.cipher_key,
en->cipher_key_len);
if (error)
goto out;
bcopy(en->cipher_key, tls->params.cipher_key, en->cipher_key_len);
/*
* This holds the implicit portion of the nonce for AEAD
@ -722,9 +792,7 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
*/
if (en->iv_len != 0) {
tls->params.iv_len = en->iv_len;
error = copyin(en->iv, tls->params.iv, en->iv_len);
if (error)
goto out;
bcopy(en->iv, tls->params.iv, en->iv_len);
/*
* For TLS 1.2 with GCM, generate an 8-byte nonce as a
@ -740,10 +808,6 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
*tlsp = tls;
return (0);
out:
ktls_free(tls);
return (error);
}
static struct ktls_session *

View file

@ -1914,37 +1914,6 @@ CTASSERT(TCP_CA_NAME_MAX <= TCP_LOG_ID_LEN);
CTASSERT(TCP_LOG_REASON_LEN <= TCP_LOG_ID_LEN);
#endif
#ifdef KERN_TLS
static int
copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
{
struct tls_enable_v0 tls_v0;
int error;
if (sopt->sopt_valsize == sizeof(tls_v0)) {
error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
sizeof(tls_v0));
if (error)
return (error);
memset(tls, 0, sizeof(*tls));
tls->cipher_key = tls_v0.cipher_key;
tls->iv = tls_v0.iv;
tls->auth_key = tls_v0.auth_key;
tls->cipher_algorithm = tls_v0.cipher_algorithm;
tls->cipher_key_len = tls_v0.cipher_key_len;
tls->iv_len = tls_v0.iv_len;
tls->auth_algorithm = tls_v0.auth_algorithm;
tls->auth_key_len = tls_v0.auth_key_len;
tls->flags = tls_v0.flags;
tls->tls_vmajor = tls_v0.tls_vmajor;
tls->tls_vminor = tls_v0.tls_vminor;
return (0);
}
return (sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls)));
}
#endif
extern struct cc_algo newreno_cc_algo;
static int
@ -2292,15 +2261,16 @@ tcp_default_ctloutput(struct tcpcb *tp, struct sockopt *sopt)
#ifdef KERN_TLS
case TCP_TXTLS_ENABLE:
INP_WUNLOCK(inp);
error = copyin_tls_enable(sopt, &tls);
if (error)
error = ktls_copyin_tls_enable(sopt, &tls);
if (error != 0)
break;
error = ktls_enable_tx(so, &tls);
ktls_cleanup_tls_enable(&tls);
break;
case TCP_TXTLS_MODE:
INP_WUNLOCK(inp);
error = sooptcopyin(sopt, &ui, sizeof(ui), sizeof(ui));
if (error)
if (error != 0)
return (error);
INP_WLOCK_RECHECK(inp);
@ -2309,11 +2279,11 @@ tcp_default_ctloutput(struct tcpcb *tp, struct sockopt *sopt)
break;
case TCP_RXTLS_ENABLE:
INP_WUNLOCK(inp);
error = sooptcopyin(sopt, &tls, sizeof(tls),
sizeof(tls));
if (error)
error = ktls_copyin_tls_enable(sopt, &tls);
if (error != 0)
break;
error = ktls_enable_rx(so, &tls);
ktls_cleanup_tls_enable(&tls);
break;
#endif
case TCP_MAXUNACKTIME:

View file

@ -174,6 +174,7 @@ struct m_snd_tag;
struct mbuf;
struct sockbuf;
struct socket;
struct sockopt;
struct ktls_session {
struct ktls_ocf_session *ocf_session;
@ -213,27 +214,29 @@ typedef enum {
} ktls_mbuf_crypto_st_t;
void ktls_check_rx(struct sockbuf *sb);
ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
void ktls_cleanup_tls_enable(struct tls_enable *tls);
int ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls);
void ktls_disable_ifnet(void *arg);
int ktls_enable_rx(struct socket *so, struct tls_enable *en);
int ktls_enable_tx(struct socket *so, struct tls_enable *en);
void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
void ktls_enqueue_to_free(struct mbuf *m);
void ktls_destroy(struct ktls_session *tls);
void ktls_frame(struct mbuf *m, struct ktls_session *tls, int *enqueue_cnt,
uint8_t record_type);
bool ktls_permit_empty_frames(struct ktls_session *tls);
void ktls_seq(struct sockbuf *sb, struct mbuf *m);
void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
void ktls_enqueue_to_free(struct mbuf *m);
int ktls_get_rx_mode(struct socket *so, int *modep);
int ktls_set_tx_mode(struct socket *so, int mode);
int ktls_get_tx_mode(struct socket *so, int *modep);
int ktls_get_rx_sequence(struct inpcb *inp, uint32_t *tcpseq, uint64_t *tlsseq);
void ktls_input_ifp_mismatch(struct sockbuf *sb, struct ifnet *ifp);
int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
#ifdef RATELIMIT
int ktls_modify_txrtlmt(struct ktls_session *tls, uint64_t max_pacing_rate);
#endif
int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
bool ktls_pending_rx_info(struct sockbuf *sb, uint64_t *seqnop, size_t *residp);
bool ktls_permit_empty_frames(struct ktls_session *tls);
void ktls_seq(struct sockbuf *sb, struct mbuf *m);
int ktls_set_tx_mode(struct socket *so, int mode);
static inline struct ktls_session *
ktls_hold(struct ktls_session *tls)