Merge pull request #27254 from poettering/cmsg-align-check

socket-util: tighten CMSG_TYPED_DATA() alignment checks
This commit is contained in:
Yu Watanabe 2023-04-14 13:49:04 +09:00 committed by GitHub
commit 13524b29a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 39 additions and 38 deletions

View file

@ -1047,7 +1047,7 @@ ssize_t receive_one_fd_iov(
} }
if (found) if (found)
*ret_fd = *(int*) CMSG_DATA(found); *ret_fd = *CMSG_TYPED_DATA(found, int);
else else
*ret_fd = -EBADF; *ret_fd = -EBADF;

View file

@ -175,9 +175,16 @@ int flush_accept(int fd);
#define CMSG_FOREACH(cmsg, mh) \ #define CMSG_FOREACH(cmsg, mh) \
for ((cmsg) = CMSG_FIRSTHDR(mh); (cmsg); (cmsg) = CMSG_NXTHDR((mh), (cmsg))) for ((cmsg) = CMSG_FIRSTHDR(mh); (cmsg); (cmsg) = CMSG_NXTHDR((mh), (cmsg)))
/* Returns the cmsghdr's data pointer, but safely cast to the specified type. Does two alignment checks: one
* at compile time, that the requested type has a smaller or same alignment as 'struct cmsghdr', and one
* during runtime, that the actual pointer matches the alignment too. This is supposed to catch cases such as
* 'struct timeval' is embedded into 'struct cmsghdr' on architectures where the alignment of the former is 8
* bytes (because of a 64bit time_t), but of the latter is 4 bytes (because size_t is 32bit), such as
* riscv32. */
#define CMSG_TYPED_DATA(cmsg, type) \ #define CMSG_TYPED_DATA(cmsg, type) \
({ \ ({ \
struct cmsghdr *_cmsg = cmsg; \ struct cmsghdr *_cmsg = cmsg; \
assert_cc(__alignof__(type) <= __alignof__(struct cmsghdr)); \
_cmsg ? CAST_ALIGN_PTR(type, CMSG_DATA(_cmsg)) : (type*) NULL; \ _cmsg ? CAST_ALIGN_PTR(type, CMSG_DATA(_cmsg)) : (type*) NULL; \
}) })

View file

@ -2629,7 +2629,7 @@ static int manager_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
assert(!fd_array); assert(!fd_array);
fd_array = (int*) CMSG_DATA(cmsg); fd_array = CMSG_TYPED_DATA(cmsg, int);
n_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); n_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
} else if (cmsg->cmsg_level == SOL_SOCKET && } else if (cmsg->cmsg_level == SOL_SOCKET &&
@ -2637,7 +2637,7 @@ static int manager_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t
cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) { cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
assert(!ucred); assert(!ucred);
ucred = (struct ucred*) CMSG_DATA(cmsg); ucred = CMSG_TYPED_DATA(cmsg, struct ucred);
} }
} }

View file

@ -1094,7 +1094,7 @@ static int process_socket(int fd) {
} }
assert(input_fd < 0); assert(input_fd < 0);
input_fd = *(int*) CMSG_DATA(found); input_fd = *CMSG_TYPED_DATA(found, int);
break; break;
} else } else
cmsg_close_all(&mh); cmsg_close_all(&mh);

View file

@ -1096,7 +1096,7 @@ static ssize_t read_datagram(
cmsg->cmsg_type == SCM_CREDENTIALS && cmsg->cmsg_type == SCM_CREDENTIALS &&
cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) { cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
assert(!sender); assert(!sender);
sender = (struct ucred*) CMSG_DATA(cmsg); sender = CMSG_TYPED_DATA(cmsg, struct ucred);
} }
if (cmsg->cmsg_level == SOL_SOCKET && if (cmsg->cmsg_level == SOL_SOCKET &&
@ -1108,7 +1108,7 @@ static ssize_t read_datagram(
} }
assert(passed_fd < 0); assert(passed_fd < 0);
passed_fd = *(int*) CMSG_DATA(cmsg); passed_fd = *CMSG_TYPED_DATA(cmsg, int);
} }
} }

View file

@ -1486,21 +1486,21 @@ int server_process_datagram(
cmsg->cmsg_type == SCM_CREDENTIALS && cmsg->cmsg_type == SCM_CREDENTIALS &&
cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) { cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
assert(!ucred); assert(!ucred);
ucred = (struct ucred*) CMSG_DATA(cmsg); ucred = CMSG_TYPED_DATA(cmsg, struct ucred);
} else if (cmsg->cmsg_level == SOL_SOCKET && } else if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SCM_SECURITY) { cmsg->cmsg_type == SCM_SECURITY) {
assert(!label); assert(!label);
label = (char*) CMSG_DATA(cmsg); label = CMSG_TYPED_DATA(cmsg, char);
label_len = cmsg->cmsg_len - CMSG_LEN(0); label_len = cmsg->cmsg_len - CMSG_LEN(0);
} else if (cmsg->cmsg_level == SOL_SOCKET && } else if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SO_TIMESTAMP && cmsg->cmsg_type == SO_TIMESTAMP &&
cmsg->cmsg_len == CMSG_LEN(sizeof(struct timeval))) { cmsg->cmsg_len == CMSG_LEN(sizeof(struct timeval))) {
assert(!tv); assert(!tv);
tv = (struct timeval*) CMSG_DATA(cmsg); tv = CMSG_TYPED_DATA(cmsg, struct timeval);
} else if (cmsg->cmsg_level == SOL_SOCKET && } else if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SCM_RIGHTS) { cmsg->cmsg_type == SCM_RIGHTS) {
assert(!fds); assert(!fds);
fds = (int*) CMSG_DATA(cmsg); fds = CMSG_TYPED_DATA(cmsg, int);
n_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); n_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
} }

View file

@ -192,7 +192,7 @@ int icmp6_receive(int fd, void *buffer, size_t size, struct in6_addr *ret_dst,
if (cmsg->cmsg_level == SOL_IPV6 && if (cmsg->cmsg_level == SOL_IPV6 &&
cmsg->cmsg_type == IPV6_HOPLIMIT && cmsg->cmsg_type == IPV6_HOPLIMIT &&
cmsg->cmsg_len == CMSG_LEN(sizeof(int))) { cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
int hops = *(int*) CMSG_DATA(cmsg); int hops = *CMSG_TYPED_DATA(cmsg, int);
if (hops != 255) if (hops != 255)
return -EMULTIHOP; return -EMULTIHOP;
@ -201,7 +201,7 @@ int icmp6_receive(int fd, void *buffer, size_t size, struct in6_addr *ret_dst,
if (cmsg->cmsg_level == SOL_SOCKET && if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SO_TIMESTAMP && cmsg->cmsg_type == SO_TIMESTAMP &&
cmsg->cmsg_len == CMSG_LEN(sizeof(struct timeval))) cmsg->cmsg_len == CMSG_LEN(sizeof(struct timeval)))
triple_timestamp_from_realtime(&t, timeval_load((struct timeval*) CMSG_DATA(cmsg))); triple_timestamp_from_realtime(&t, timeval_load(CMSG_TYPED_DATA(cmsg, struct timeval)));
} }
if (!triple_timestamp_is_set(&t)) if (!triple_timestamp_is_set(&t))

View file

@ -1981,7 +1981,7 @@ static int client_receive_message_raw(
cmsg = cmsg_find(&msg, SOL_PACKET, PACKET_AUXDATA, CMSG_LEN(sizeof(struct tpacket_auxdata))); cmsg = cmsg_find(&msg, SOL_PACKET, PACKET_AUXDATA, CMSG_LEN(sizeof(struct tpacket_auxdata)));
if (cmsg) { if (cmsg) {
struct tpacket_auxdata *aux = (struct tpacket_auxdata*) CMSG_DATA(cmsg); struct tpacket_auxdata *aux = CMSG_TYPED_DATA(cmsg, struct tpacket_auxdata);
checksum = !(aux->tp_status & TP_STATUS_CSUMNOTREADY); checksum = !(aux->tp_status & TP_STATUS_CSUMNOTREADY);
} }

View file

@ -1310,7 +1310,7 @@ static int server_receive_message(sd_event_source *s, int fd,
if (cmsg->cmsg_level == IPPROTO_IP && if (cmsg->cmsg_level == IPPROTO_IP &&
cmsg->cmsg_type == IP_PKTINFO && cmsg->cmsg_type == IP_PKTINFO &&
cmsg->cmsg_len == CMSG_LEN(sizeof(struct in_pktinfo))) { cmsg->cmsg_len == CMSG_LEN(sizeof(struct in_pktinfo))) {
struct in_pktinfo *info = (struct in_pktinfo*)CMSG_DATA(cmsg); struct in_pktinfo *info = CMSG_TYPED_DATA(cmsg, struct in_pktinfo);
/* TODO figure out if this can be done as a filter on /* TODO figure out if this can be done as a filter on
* the socket, like for IPv6 */ * the socket, like for IPv6 */

View file

@ -604,7 +604,7 @@ static int bus_socket_read_auth(sd_bus *b) {
* protocol? Somebody is playing games with * protocol? Somebody is playing games with
* us. Close them all, and fail */ * us. Close them all, and fail */
j = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); j = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
close_many((int*) CMSG_DATA(cmsg), j); close_many(CMSG_TYPED_DATA(cmsg, int), j);
return -EIO; return -EIO;
} else } else
log_debug("Got unexpected auxiliary data with level=%d and type=%d", log_debug("Got unexpected auxiliary data with level=%d and type=%d",
@ -1268,18 +1268,18 @@ int bus_socket_read_message(sd_bus *bus) {
* isn't actually enabled? Close them, * isn't actually enabled? Close them,
* and fail */ * and fail */
close_many((int*) CMSG_DATA(cmsg), n); close_many(CMSG_TYPED_DATA(cmsg, int), n);
return -EIO; return -EIO;
} }
f = reallocarray(bus->fds, bus->n_fds + n, sizeof(int)); f = reallocarray(bus->fds, bus->n_fds + n, sizeof(int));
if (!f) { if (!f) {
close_many((int*) CMSG_DATA(cmsg), n); close_many(CMSG_TYPED_DATA(cmsg, int), n);
return -ENOMEM; return -ENOMEM;
} }
for (i = 0; i < n; i++) for (i = 0; i < n; i++)
f[bus->n_fds++] = fd_move_above_stdio(((int*) CMSG_DATA(cmsg))[i]); f[bus->n_fds++] = fd_move_above_stdio(CMSG_TYPED_DATA(cmsg, int)[i]);
bus->fds = f; bus->fds = f;
} else } else
log_debug("Got unexpected auxiliary data with level=%d and type=%d", log_debug("Got unexpected auxiliary data with level=%d and type=%d",

View file

@ -503,7 +503,6 @@ int device_monitor_receive_device(sd_device_monitor *m, sd_device **ret) {
.msg_name = &snl, .msg_name = &snl,
.msg_namelen = sizeof(snl), .msg_namelen = sizeof(snl),
}; };
struct cmsghdr *cmsg;
struct ucred *cred; struct ucred *cred;
size_t offset; size_t offset;
ssize_t n; ssize_t n;
@ -559,12 +558,11 @@ int device_monitor_receive_device(sd_device_monitor *m, sd_device **ret) {
snl.nl.nl_pid); snl.nl.nl_pid);
} }
cmsg = CMSG_FIRSTHDR(&smsg); cred = CMSG_FIND_DATA(&smsg, SOL_SOCKET, SCM_CREDENTIALS, struct ucred);
if (!cmsg || cmsg->cmsg_type != SCM_CREDENTIALS) if (!cred)
return log_monitor_errno(m, SYNTHETIC_ERRNO(EAGAIN), return log_monitor_errno(m, SYNTHETIC_ERRNO(EAGAIN),
"No sender credentials received, ignoring message."); "No sender credentials received, ignoring message.");
cred = (struct ucred*) CMSG_DATA(cmsg);
if (!check_sender_uid(m, cred->uid)) if (!check_sender_uid(m, cred->uid))
return log_monitor_errno(m, SYNTHETIC_ERRNO(EAGAIN), return log_monitor_errno(m, SYNTHETIC_ERRNO(EAGAIN),
"Sender uid="UID_FMT", message ignored.", cred->uid); "Sender uid="UID_FMT", message ignored.", cred->uid);

View file

@ -147,7 +147,7 @@ static int dns_stream_identify(DnsStream *s) {
switch (cmsg->cmsg_type) { switch (cmsg->cmsg_type) {
case IPV6_PKTINFO: { case IPV6_PKTINFO: {
struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg); struct in6_pktinfo *i = CMSG_TYPED_DATA(cmsg, struct in6_pktinfo);
if (s->ifindex <= 0) if (s->ifindex <= 0)
s->ifindex = i->ipi6_ifindex; s->ifindex = i->ipi6_ifindex;
@ -155,7 +155,7 @@ static int dns_stream_identify(DnsStream *s) {
} }
case IPV6_HOPLIMIT: case IPV6_HOPLIMIT:
s->ttl = *(int *) CMSG_DATA(cmsg); s->ttl = *CMSG_TYPED_DATA(cmsg, int);
break; break;
} }
@ -165,7 +165,7 @@ static int dns_stream_identify(DnsStream *s) {
switch (cmsg->cmsg_type) { switch (cmsg->cmsg_type) {
case IP_PKTINFO: { case IP_PKTINFO: {
struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg); struct in_pktinfo *i = CMSG_TYPED_DATA(cmsg, struct in_pktinfo);
if (s->ifindex <= 0) if (s->ifindex <= 0)
s->ifindex = i->ipi_ifindex; s->ifindex = i->ipi_ifindex;
@ -173,7 +173,7 @@ static int dns_stream_identify(DnsStream *s) {
} }
case IP_TTL: case IP_TTL:
s->ttl = *(int *) CMSG_DATA(cmsg); s->ttl = *CMSG_TYPED_DATA(cmsg, int);
break; break;
} }
} }

View file

@ -834,7 +834,7 @@ int manager_recv(Manager *m, int fd, DnsProtocol protocol, DnsPacket **ret) {
switch (cmsg->cmsg_type) { switch (cmsg->cmsg_type) {
case IPV6_PKTINFO: { case IPV6_PKTINFO: {
struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg); struct in6_pktinfo *i = CMSG_TYPED_DATA(cmsg, struct in6_pktinfo);
if (p->ifindex <= 0) if (p->ifindex <= 0)
p->ifindex = i->ipi6_ifindex; p->ifindex = i->ipi6_ifindex;
@ -844,11 +844,11 @@ int manager_recv(Manager *m, int fd, DnsProtocol protocol, DnsPacket **ret) {
} }
case IPV6_HOPLIMIT: case IPV6_HOPLIMIT:
p->ttl = *(int *) CMSG_DATA(cmsg); p->ttl = *CMSG_TYPED_DATA(cmsg, int);
break; break;
case IPV6_RECVFRAGSIZE: case IPV6_RECVFRAGSIZE:
p->fragsize = *(int *) CMSG_DATA(cmsg); p->fragsize = *CMSG_TYPED_DATA(cmsg, int);
break; break;
} }
} else if (cmsg->cmsg_level == IPPROTO_IP) { } else if (cmsg->cmsg_level == IPPROTO_IP) {
@ -857,7 +857,7 @@ int manager_recv(Manager *m, int fd, DnsProtocol protocol, DnsPacket **ret) {
switch (cmsg->cmsg_type) { switch (cmsg->cmsg_type) {
case IP_PKTINFO: { case IP_PKTINFO: {
struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg); struct in_pktinfo *i = CMSG_TYPED_DATA(cmsg, struct in_pktinfo);
if (p->ifindex <= 0) if (p->ifindex <= 0)
p->ifindex = i->ipi_ifindex; p->ifindex = i->ipi_ifindex;
@ -867,11 +867,11 @@ int manager_recv(Manager *m, int fd, DnsProtocol protocol, DnsPacket **ret) {
} }
case IP_TTL: case IP_TTL:
p->ttl = *(int *) CMSG_DATA(cmsg); p->ttl = *CMSG_TYPED_DATA(cmsg, int);
break; break;
case IP_RECVFRAGSIZE: case IP_RECVFRAGSIZE:
p->fragsize = *(int *) CMSG_DATA(cmsg); p->fragsize = *CMSG_TYPED_DATA(cmsg, int);
break; break;
} }
} }

View file

@ -161,7 +161,6 @@ static int udev_ctrl_connection_event_handler(sd_event_source *s, int fd, uint32
.msg_control = &control, .msg_control = &control,
.msg_controllen = sizeof(control), .msg_controllen = sizeof(control),
}; };
struct cmsghdr *cmsg;
struct ucred *cred; struct ucred *cred;
ssize_t size; ssize_t size;
@ -185,15 +184,12 @@ static int udev_ctrl_connection_event_handler(sd_event_source *s, int fd, uint32
cmsg_close_all(&smsg); cmsg_close_all(&smsg);
cmsg = CMSG_FIRSTHDR(&smsg); cred = CMSG_FIND_DATA(&smsg, SOL_SOCKET, SCM_CREDENTIALS, struct ucred);
if (!cred) {
if (!cmsg || cmsg->cmsg_type != SCM_CREDENTIALS) {
log_error("No sender credentials received, ignoring message"); log_error("No sender credentials received, ignoring message");
return 0; return 0;
} }
cred = (struct ucred *) CMSG_DATA(cmsg);
if (cred->uid != 0) { if (cred->uid != 0) {
log_error("Invalid sender uid "UID_FMT", ignoring message", cred->uid); log_error("Invalid sender uid "UID_FMT", ignoring message", cred->uid);
return 0; return 0;