]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
fd-util: port close_all_fds() to close_range()
authorLennart Poettering <lennart@poettering.net>
Tue, 13 Oct 2020 16:06:45 +0000 (18:06 +0200)
committerLennart Poettering <lennart@poettering.net>
Wed, 14 Oct 2020 08:40:29 +0000 (10:40 +0200)
src/basic/fd-util.c

index db869cbd5499a8827e8bb72ae5d519a0a45663cb..e37b6944a8a6fe7c417e68f61ca6a9a5efddfbdc 100644 (file)
@@ -21,6 +21,7 @@
 #include "path-util.h"
 #include "process-util.h"
 #include "socket-util.h"
+#include "sort-util.h"
 #include "stat-util.h"
 #include "stdio-util.h"
 #include "tmpfile-util.h"
@@ -210,13 +211,102 @@ static int get_max_fd(void) {
         return (int) (m - 1);
 }
 
+static int cmp_int(const int *a, const int *b) {
+        return CMP(*a, *b);
+}
+
 int close_all_fds(const int except[], size_t n_except) {
+        static bool have_close_range = true; /* Assume we live in the future */
         _cleanup_closedir_ DIR *d = NULL;
         struct dirent *de;
         int r = 0;
 
         assert(n_except == 0 || except);
 
+        if (have_close_range) {
+                /* In the best case we have close_range() to close all fds between a start and an end fd,
+                 * which we can use on the "inverted" exception array, i.e. all intervals between all
+                 * adjacent pairs from the sorted exception array. This changes loop complexity from O(n)
+                 * where n is number of open fds to O(m⋅log(m)) where m is the number of fds to keep
+                 * open. Given that we assume n ≫ m that's preferable to us. */
+
+                if (n_except == 0) {
+                        /* Close everything. Yay! */
+
+                        if (close_range(3, -1, 0) >= 0)
+                                return 1;
+
+                        if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
+                                return -errno;
+
+                        have_close_range = false;
+                } else {
+                        _cleanup_free_ int *sorted_malloc = NULL;
+                        size_t n_sorted;
+                        int *sorted;
+
+                        assert(n_except < SIZE_MAX);
+                        n_sorted = n_except + 1;
+
+                        if (n_sorted > 64) /* Use heap for large numbers of fds, stack otherwise */
+                                sorted = sorted_malloc = new(int, n_sorted);
+                        else
+                                sorted = newa(int, n_sorted);
+
+                        if (sorted) {
+                                int c = 0;
+
+                                memcpy(sorted, except, n_except * sizeof(int));
+
+                                /* Let's add fd 2 to the list of fds, to simplify the loop below, as this
+                                 * allows us to cover the head of the array the same way as the body */
+                                sorted[n_sorted-1] = 2;
+
+                                typesafe_qsort(sorted, n_sorted, cmp_int);
+
+                                for (size_t i = 0; i < n_sorted-1; i++) {
+                                        int start, end;
+
+                                        start = MAX(sorted[i], 2); /* The first three fds shall always remain open */
+                                        end = MAX(sorted[i+1], 2);
+
+                                        assert(end >= start);
+
+                                        if (end - start <= 1)
+                                                continue;
+
+                                        /* Close everything between the start and end fds (both of which shall stay open) */
+                                        if (close_range(start + 1, end - 1, 0) < 0) {
+                                                if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
+                                                        return -errno;
+
+                                                have_close_range = false;
+                                                break;
+                                        }
+
+                                        c += end - start - 1;
+                                }
+
+                                if (have_close_range) {
+                                        /* The loop succeeded. Let's now close everything beyond the end */
+
+                                        if (sorted[n_sorted-1] >= INT_MAX) /* Dont let the addition below overflow */
+                                                return c;
+
+                                        if (close_range(sorted[n_sorted-1] + 1, -1, 0) >= 0)
+                                                return c + 1;
+
+                                        if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
+                                                return -errno;
+
+                                        have_close_range = false;
+                                }
+                        }
+                }
+
+                /* Fallback on OOM or if close_range() is not supported */
+        }
+
         d = opendir("/proc/self/fd");
         if (!d) {
                 int fd, max_fd;