]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
fd-util: split out close_all_fds() special case handling and call it from close_all_f...
authorLennart Poettering <lennart@poettering.net>
Tue, 12 Oct 2021 14:11:46 +0000 (16:11 +0200)
committerLennart Poettering <lennart@poettering.net>
Wed, 27 Oct 2021 15:56:36 +0000 (17:56 +0200)
The optimization is useful there too.

src/basic/fd-util.c

index 70c036db173ec23b14eb2ed58eb793df170ba501..50666e6375c9d23fb0bd8ad5ac43848e1cee59c3 100644 (file)
@@ -208,7 +208,7 @@ static int get_max_fd(void) {
         return (int) (m - 1);
 }
 
-int close_all_fds_without_malloc(const int except[], size_t n_except) {
+static int close_all_fds_frugal(const int except[], size_t n_except) {
         int max_fd, r = 0;
 
         assert(n_except == 0 || except);
@@ -243,104 +243,143 @@ int close_all_fds_without_malloc(const int except[], size_t n_except) {
         return r;
 }
 
-int close_all_fds(const int except[], size_t n_except) {
-        static bool have_close_range = true; /* Assume we live in the future */
-        _cleanup_closedir_ DIR *d = NULL;
-        struct dirent *de;
-        int r = 0;
+static bool have_close_range = true; /* Assume we live in the future */
 
+static int close_all_fds_special_case(const int except[], size_t n_except) {
         assert(n_except == 0 || except);
 
-        if (have_close_range) {
-                /* In the best case we have close_range() to close all fds between a start and an end fd,
-                 * which we can use on the "inverted" exception array, i.e. all intervals between all
-                 * adjacent pairs from the sorted exception array. This changes loop complexity from O(n)
-                 * where n is number of open fds to O(m⋅log(m)) where m is the number of fds to keep
-                 * open. Given that we assume n ≫ m that's preferable to us. */
+        /* Handles a few common special cases separately, since they are common and can be optimized really
+         * nicely, since we won't need sorting for them. Returns > 0 if the special casing worked, 0
+         * otherwise. */
 
-                if (n_except == 0) {
-                        /* Close everything. Yay! */
+        if (!have_close_range)
+                return 0;
 
-                        if (close_range(3, -1, 0) >= 0)
-                                return 0;
+        switch (n_except) {
 
-                        if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
-                                return -errno;
+        case 0:
+                /* Close everything. Yay! */
 
-                        have_close_range = false;
+                if (close_range(3, -1, 0) >= 0)
+                        return 1;
 
-                } else if (n_except == 1) {
+                if (ERRNO_IS_NOT_SUPPORTED(errno) || ERRNO_IS_PRIVILEGE(errno)) {
+                        have_close_range = false;
+                        return 0;
+                }
 
-                        /* Close all but exactly one, then we don't need no sorting. This is a pretty common
-                         * case, hence let's handle it specially. */
+                return -errno;
 
-                        if ((except[0] <= 3 || close_range(3, except[0]-1, 0) >= 0) &&
-                            (except[0] >= INT_MAX || close_range(MAX(3, except[0]+1), -1, 0) >= 0))
-                                return 0;
+        case 1:
+                /* Close all but exactly one, then we don't need no sorting. This is a pretty common
+                 * case, hence let's handle it specially. */
 
-                        if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
-                                return -errno;
+                if ((except[0] <= 3 || close_range(3, except[0]-1, 0) >= 0) &&
+                    (except[0] >= INT_MAX || close_range(MAX(3, except[0]+1), -1, 0) >= 0))
+                        return 1;
 
+                if (ERRNO_IS_NOT_SUPPORTED(errno) || ERRNO_IS_PRIVILEGE(errno)) {
                         have_close_range = false;
+                        return 0;
+                }
 
-                } else {
-                        _cleanup_free_ int *sorted_malloc = NULL;
-                        size_t n_sorted;
-                        int *sorted;
+                return -errno;
+
+        default:
+                return 0;
+        }
+}
+
+int close_all_fds_without_malloc(const int except[], size_t n_except) {
+        int r;
+
+        assert(n_except == 0 || except);
 
-                        assert(n_except < SIZE_MAX);
-                        n_sorted = n_except + 1;
+        r = close_all_fds_special_case(except, n_except);
+        if (r < 0)
+                return r;
+        if (r > 0) /* special case worked! */
+                return 0;
+
+        return close_all_fds_frugal(except, n_except);
+}
+
+int close_all_fds(const int except[], size_t n_except) {
+        _cleanup_closedir_ DIR *d = NULL;
+        struct dirent *de;
+        int r = 0;
 
-                        if (n_sorted > 64) /* Use heap for large numbers of fds, stack otherwise */
-                                sorted = sorted_malloc = new(int, n_sorted);
-                        else
-                                sorted = newa(int, n_sorted);
+        assert(n_except == 0 || except);
 
-                        if (sorted) {
-                                memcpy(sorted, except, n_except * sizeof(int));
+        r = close_all_fds_special_case(except, n_except);
+        if (r < 0)
+                return r;
+        if (r > 0) /* special case worked! */
+                return 0;
 
-                                /* Let's add fd 2 to the list of fds, to simplify the loop below, as this
-                                 * allows us to cover the head of the array the same way as the body */
-                                sorted[n_sorted-1] = 2;
+        if (have_close_range) {
+                _cleanup_free_ int *sorted_malloc = NULL;
+                size_t n_sorted;
+                int *sorted;
 
-                                typesafe_qsort(sorted, n_sorted, cmp_int);
+                /* In the best case we have close_range() to close all fds between a start and an end fd,
+                 * which we can use on the "inverted" exception array, i.e. all intervals between all
+                 * adjacent pairs from the sorted exception array. This changes loop complexity from O(n)
+                 * where n is number of open fds to O(m⋅log(m)) where m is the number of fds to keep
+                 * open. Given that we assume n ≫ m that's preferable to us. */
 
-                                for (size_t i = 0; i < n_sorted-1; i++) {
-                                        int start, end;
+                assert(n_except < SIZE_MAX);
+                n_sorted = n_except + 1;
 
-                                        start = MAX(sorted[i], 2); /* The first three fds shall always remain open */
-                                        end = MAX(sorted[i+1], 2);
+                if (n_sorted > 64) /* Use heap for large numbers of fds, stack otherwise */
+                        sorted = sorted_malloc = new(int, n_sorted);
+                else
+                        sorted = newa(int, n_sorted);
 
-                                        assert(end >= start);
+                if (sorted) {
+                        memcpy(sorted, except, n_except * sizeof(int));
 
-                                        if (end - start <= 1)
-                                                continue;
+                        /* Let's add fd 2 to the list of fds, to simplify the loop below, as this
+                         * allows us to cover the head of the array the same way as the body */
+                        sorted[n_sorted-1] = 2;
 
-                                        /* Close everything between the start and end fds (both of which shall stay open) */
-                                        if (close_range(start + 1, end - 1, 0) < 0) {
-                                                if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
-                                                        return -errno;
+                        typesafe_qsort(sorted, n_sorted, cmp_int);
 
-                                                have_close_range = false;
-                                                break;
-                                        }
-                                }
+                        for (size_t i = 0; i < n_sorted-1; i++) {
+                                int start, end;
 
-                                if (have_close_range) {
-                                        /* The loop succeeded. Let's now close everything beyond the end */
+                                start = MAX(sorted[i], 2); /* The first three fds shall always remain open */
+                                end = MAX(sorted[i+1], 2);
 
-                                        if (sorted[n_sorted-1] >= INT_MAX) /* Dont let the addition below overflow */
-                                                return 0;
+                                assert(end >= start);
 
-                                        if (close_range(sorted[n_sorted-1] + 1, -1, 0) >= 0)
-                                                return 0;
+                                if (end - start <= 1)
+                                        continue;
 
+                                /* Close everything between the start and end fds (both of which shall stay open) */
+                                if (close_range(start + 1, end - 1, 0) < 0) {
                                         if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
                                                 return -errno;
 
                                         have_close_range = false;
+                                        break;
                                 }
                         }
+
+                        if (have_close_range) {
+                                /* The loop succeeded. Let's now close everything beyond the end */
+
+                                if (sorted[n_sorted-1] >= INT_MAX) /* Dont let the addition below overflow */
+                                        return 0;
+
+                                if (close_range(sorted[n_sorted-1] + 1, -1, 0) >= 0)
+                                        return 0;
+
+                                if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno))
+                                        return -errno;
+
+                                have_close_range = false;
+                        }
                 }
 
                 /* Fallback on OOM or if close_range() is not supported */
@@ -348,7 +387,7 @@ int close_all_fds(const int except[], size_t n_except) {
 
         d = opendir("/proc/self/fd");
         if (!d)
-                return close_all_fds_without_malloc(except, n_except); /* ultimate fallback if /proc/ is not available */
+                return close_all_fds_frugal(except, n_except); /* ultimate fallback if /proc/ is not available */
 
         FOREACH_DIRENT(de, d, return -errno) {
                 int fd = -1, q;