fd-util: Expose helper to pack fds into 3,4,5,...

This is useful for situations where an array of FDs is to be passed into
a child process (i.e. by passing it through safe_fork). This function
can be called in the child (before calling exec) to pack the FDs to all
be next to each-other starting from SD_LISTEN_FDS_START (i.e. 3)
This commit is contained in:
Adrian Vovk 2024-02-13 15:09:54 -05:00 committed by Luca Boccassi
parent 034569150f
commit 85f660d46b
6 changed files with 82 additions and 48 deletions

View file

@ -464,6 +464,53 @@ int close_all_fds(const int except[], size_t n_except) {
return r;
}
int pack_fds(int fds[], size_t n_fds) {
if (n_fds <= 0)
return 0;
/* Shifts around the fds in the provided array such that they
* all end up packed next to each-other, in order, starting
* from SD_LISTEN_FDS_START. This must be called after close_all_fds();
* it is likely to freeze up otherwise. You should probably use safe_fork_full
* with FORK_CLOSE_ALL_FDS|FORK_PACK_FDS set, to ensure that this is done correctly.
* The fds array is modified in place with the new FD numbers. */
assert(fds);
for (int start = 0;;) {
int restart_from = -1;
for (int i = start; i < (int) n_fds; i++) {
int nfd;
/* Already at right index? */
if (fds[i] == i + 3)
continue;
nfd = fcntl(fds[i], F_DUPFD, i + 3);
if (nfd < 0)
return -errno;
safe_close(fds[i]);
fds[i] = nfd;
/* Hmm, the fd we wanted isn't free? Then
* let's remember that and try again from here */
if (nfd != i + 3 && restart_from < 0)
restart_from = i;
}
if (restart_from < 0)
break;
start = restart_from;
}
assert(fds[0] == 3);
return 0;
}
int same_fd(int a, int b) {
struct stat sta, stb;
pid_t pid;

View file

@ -77,6 +77,8 @@ int get_max_fd(void);
int close_all_fds(const int except[], size_t n_except);
int close_all_fds_without_malloc(const int except[], size_t n_except);
int pack_fds(int fds[], size_t n);
int same_fd(int a, int b);
void cmsg_close_all(struct msghdr *mh);

View file

@ -1468,7 +1468,7 @@ static int fork_flags_to_signal(ForkFlags flags) {
int safe_fork_full(
const char *name,
const int stdio_fds[3],
const int except_fds[],
int except_fds[],
size_t n_except_fds,
ForkFlags flags,
pid_t *ret_pid) {
@ -1697,6 +1697,19 @@ int safe_fork_full(
}
}
if (flags & FORK_PACK_FDS) {
/* FORK_CLOSE_ALL_FDS ensures that except_fds are the only FDs >= 3 that are
* open, this is including the log. This is required by pack_fds, which will
* get stuck in an infinite loop of any FDs other than except_fds are open. */
assert(FLAGS_SET(flags, FORK_CLOSE_ALL_FDS));
r = pack_fds(except_fds, n_except_fds);
if (r < 0) {
log_full_errno(prio, r, "Failed to pack file descriptors: %m");
_exit(EXIT_FAILURE);
}
}
if (flags & FORK_CLOEXEC_OFF) {
r = fd_cloexec_many(except_fds, n_except_fds, false);
if (r < 0) {
@ -1736,7 +1749,7 @@ int safe_fork_full(
int pidref_safe_fork_full(
const char *name,
const int stdio_fds[3],
const int except_fds[],
int except_fds[],
size_t n_except_fds,
ForkFlags flags,
PidRef *ret_pid) {
@ -1760,7 +1773,7 @@ int pidref_safe_fork_full(
int namespace_fork(
const char *outer_name,
const char *inner_name,
const int except_fds[],
int except_fds[],
size_t n_except_fds,
ForkFlags flags,
int pidns_fd,

View file

@ -185,12 +185,13 @@ typedef enum ForkFlags {
FORK_KEEP_NOTIFY_SOCKET = 1 << 17, /* Unless this specified, $NOTIFY_SOCKET will be unset. */
FORK_DETACH = 1 << 18, /* Double fork if needed to ensure PID1/subreaper is parent */
FORK_NEW_NETNS = 1 << 19, /* Run child in its own network namespace 💣 DO NOT USE IN THREADED PROGRAMS! 💣 */
FORK_PACK_FDS = 1 << 20, /* Rearrange the passed FDs to be FD 3,4,5,etc. Updates the array in place (combine with FORK_CLOSE_ALL_FDS!) */
} ForkFlags;
int safe_fork_full(
const char *name,
const int stdio_fds[3],
const int except_fds[],
int except_fds[],
size_t n_except_fds,
ForkFlags flags,
pid_t *ret_pid);
@ -202,7 +203,7 @@ static inline int safe_fork(const char *name, ForkFlags flags, pid_t *ret_pid) {
int pidref_safe_fork_full(
const char *name,
const int stdio_fds[3],
const int except_fds[],
int except_fds[],
size_t n_except_fds,
ForkFlags flags,
PidRef *ret_pid);
@ -211,7 +212,18 @@ static inline int pidref_safe_fork(const char *name, ForkFlags flags, PidRef *re
return pidref_safe_fork_full(name, NULL, NULL, 0, flags, ret_pid);
}
int namespace_fork(const char *outer_name, const char *inner_name, const int except_fds[], size_t n_except_fds, ForkFlags flags, int pidns_fd, int mntns_fd, int netns_fd, int userns_fd, int root_fd, pid_t *ret_pid);
int namespace_fork(
const char *outer_name,
const char *inner_name,
int except_fds[],
size_t n_except_fds,
ForkFlags flags,
int pidns_fd,
int mntns_fd,
int netns_fd,
int userns_fd,
int root_fd,
pid_t *ret_pid);
int set_oom_score_adjust(int value);
int get_oom_score_adjust(int *ret);

View file

@ -66,46 +66,6 @@
#define SNDBUF_SIZE (8*1024*1024)
static int shift_fds(int fds[], size_t n_fds) {
if (n_fds <= 0)
return 0;
/* Modifies the fds array! (sorts it) */
assert(fds);
for (int start = 0;;) {
int restart_from = -1;
for (int i = start; i < (int) n_fds; i++) {
int nfd;
/* Already at right index? */
if (fds[i] == i+3)
continue;
nfd = fcntl(fds[i], F_DUPFD, i + 3);
if (nfd < 0)
return -errno;
safe_close(fds[i]);
fds[i] = nfd;
/* Hmm, the fd we wanted isn't free? Then
* let's remember that and try again from here */
if (nfd != i+3 && restart_from < 0)
restart_from = i;
}
if (restart_from < 0)
break;
start = restart_from;
}
return 0;
}
static int flag_fds(
const int fds[],
size_t n_socket_fds,
@ -4900,7 +4860,7 @@ int exec_invoke(
r = close_all_fds(keep_fds, n_keep_fds);
if (r >= 0)
r = shift_fds(params->fds, n_fds);
r = pack_fds(params->fds, n_fds);
if (r >= 0)
r = flag_fds(params->fds, n_socket_fds, n_fds, context->non_blocking);
if (r < 0) {

View file

@ -539,7 +539,7 @@ int fork_agent(const char *name, const int except[], size_t n_except, pid_t *ret
r = safe_fork_full(name,
NULL,
except,
(int*) except, /* safe_fork_full only changes except if you pass in FORK_PACK_FDS, which we don't */
n_except,
FORK_RESET_SIGNALS|FORK_DEATHSIG_SIGTERM|FORK_CLOSE_ALL_FDS|FORK_REOPEN_LOG|FORK_RLIMIT_NOFILE_SAFE,
ret_pid);