]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
selftests/filesystems: add utils.{c,h}
authorChristian Brauner <brauner@kernel.org>
Wed, 19 Feb 2025 13:23:25 +0000 (14:23 +0100)
committerChristian Brauner <brauner@kernel.org>
Wed, 19 Feb 2025 13:32:12 +0000 (14:32 +0100)
Add a new set of helpers that will be used in follow-up patches.

Signed-off-by: Christian Brauner <brauner@kernel.org>
tools/testing/selftests/filesystems/overlayfs/Makefile
tools/testing/selftests/filesystems/utils.c [new file with mode: 0644]
tools/testing/selftests/filesystems/utils.h [new file with mode: 0644]

index e8d1adb021af44588dd7af1049de66833bb584ce..6c661232b3b567ca4cf9436e48e0366b16e668b6 100644 (file)
@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: GPL-2.0
 
-TEST_GEN_PROGS := dev_in_maps set_layers_via_fds
+CFLAGS += -Wall
+CFLAGS += $(KHDR_INCLUDES)
+LDLIBS += -lcap
 
-CFLAGS := -Wall -Werror
+LOCAL_HDRS += wrappers.h log.h
+
+TEST_GEN_PROGS := dev_in_maps
+TEST_GEN_PROGS += set_layers_via_fds
 
 include ../../lib.mk
+
+$(OUTPUT)/set_layers_via_fds: ../utils.c
diff --git a/tools/testing/selftests/filesystems/utils.c b/tools/testing/selftests/filesystems/utils.c
new file mode 100644 (file)
index 0000000..0e8080b
--- /dev/null
@@ -0,0 +1,474 @@
+// SPDX-License-Identifier: GPL-2.0
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+#include <fcntl.h>
+#include <sys/types.h>
+#include <dirent.h>
+#include <grp.h>
+#include <linux/limits.h>
+#include <sched.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/eventfd.h>
+#include <sys/fsuid.h>
+#include <sys/prctl.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <sys/xattr.h>
+
+#include "utils.h"
+
+#define MAX_USERNS_LEVEL 32
+
+#define syserror(format, ...)                           \
+       ({                                              \
+               fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__); \
+               (-errno);                               \
+       })
+
+#define syserror_set(__ret__, format, ...)                    \
+       ({                                                    \
+               typeof(__ret__) __internal_ret__ = (__ret__); \
+               errno = labs(__ret__);                        \
+               fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__);       \
+               __internal_ret__;                             \
+       })
+
+#define STRLITERALLEN(x) (sizeof(""x"") - 1)
+
+#define INTTYPE_TO_STRLEN(type)             \
+       (2 + (sizeof(type) <= 1             \
+                 ? 3                       \
+                 : sizeof(type) <= 2       \
+                       ? 5                 \
+                       : sizeof(type) <= 4 \
+                             ? 10          \
+                             : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))
+
+#define list_for_each(__iterator, __list) \
+       for (__iterator = (__list)->next; __iterator != __list; __iterator = __iterator->next)
+
+typedef enum idmap_type_t {
+       ID_TYPE_UID,
+       ID_TYPE_GID
+} idmap_type_t;
+
+struct id_map {
+       idmap_type_t map_type;
+       __u32 nsid;
+       __u32 hostid;
+       __u32 range;
+};
+
+struct list {
+       void *elem;
+       struct list *next;
+       struct list *prev;
+};
+
+struct userns_hierarchy {
+       int fd_userns;
+       int fd_event;
+       unsigned int level;
+       struct list id_map;
+};
+
+static inline void list_init(struct list *list)
+{
+       list->elem = NULL;
+       list->next = list->prev = list;
+}
+
+static inline int list_empty(const struct list *list)
+{
+       return list == list->next;
+}
+
+static inline void __list_add(struct list *new, struct list *prev, struct list *next)
+{
+       next->prev = new;
+       new->next = next;
+       new->prev = prev;
+       prev->next = new;
+}
+
+static inline void list_add_tail(struct list *head, struct list *list)
+{
+       __list_add(list, head->prev, head);
+}
+
+static inline void list_del(struct list *list)
+{
+       struct list *next, *prev;
+
+       next = list->next;
+       prev = list->prev;
+       next->prev = prev;
+       prev->next = next;
+}
+
+static ssize_t read_nointr(int fd, void *buf, size_t count)
+{
+       ssize_t ret;
+
+       do {
+               ret = read(fd, buf, count);
+       } while (ret < 0 && errno == EINTR);
+
+       return ret;
+}
+
+static ssize_t write_nointr(int fd, const void *buf, size_t count)
+{
+       ssize_t ret;
+
+       do {
+               ret = write(fd, buf, count);
+       } while (ret < 0 && errno == EINTR);
+
+       return ret;
+}
+
+#define __STACK_SIZE (8 * 1024 * 1024)
+static pid_t do_clone(int (*fn)(void *), void *arg, int flags)
+{
+       void *stack;
+
+       stack = malloc(__STACK_SIZE);
+       if (!stack)
+               return -ENOMEM;
+
+#ifdef __ia64__
+       return __clone2(fn, stack, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
+#else
+       return clone(fn, stack + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
+#endif
+}
+
+static int get_userns_fd_cb(void *data)
+{
+       for (;;)
+               pause();
+       _exit(0);
+}
+
+static int wait_for_pid(pid_t pid)
+{
+       int status, ret;
+
+again:
+       ret = waitpid(pid, &status, 0);
+       if (ret == -1) {
+               if (errno == EINTR)
+                       goto again;
+
+               return -1;
+       }
+
+       if (!WIFEXITED(status))
+               return -1;
+
+       return WEXITSTATUS(status);
+}
+
+static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
+{
+       int fd = -EBADF, setgroups_fd = -EBADF;
+       int fret = -1;
+       int ret;
+       char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
+                 STRLITERALLEN("/setgroups") + 1];
+
+       if (geteuid() != 0 && map_type == ID_TYPE_GID) {
+               ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
+               if (ret < 0 || ret >= sizeof(path))
+                       goto out;
+
+               setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
+               if (setgroups_fd < 0 && errno != ENOENT) {
+                       syserror("Failed to open \"%s\"", path);
+                       goto out;
+               }
+
+               if (setgroups_fd >= 0) {
+                       ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
+                       if (ret != STRLITERALLEN("deny\n")) {
+                               syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
+                               goto out;
+                       }
+               }
+       }
+
+       ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
+       if (ret < 0 || ret >= sizeof(path))
+               goto out;
+
+       fd = open(path, O_WRONLY | O_CLOEXEC);
+       if (fd < 0) {
+               syserror("Failed to open \"%s\"", path);
+               goto out;
+       }
+
+       ret = write_nointr(fd, buf, buf_size);
+       if (ret != buf_size) {
+               syserror("Failed to write %cid mapping to \"%s\"",
+                        map_type == ID_TYPE_UID ? 'u' : 'g', path);
+               goto out;
+       }
+
+       fret = 0;
+out:
+       close(fd);
+       close(setgroups_fd);
+
+       return fret;
+}
+
+static int map_ids_from_idmap(struct list *idmap, pid_t pid)
+{
+       int fill, left;
+       char mapbuf[4096] = {};
+       bool had_entry = false;
+       idmap_type_t map_type, u_or_g;
+
+       if (list_empty(idmap))
+               return 0;
+
+       for (map_type = ID_TYPE_UID, u_or_g = 'u';
+            map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
+               char *pos = mapbuf;
+               int ret;
+               struct list *iterator;
+
+
+               list_for_each(iterator, idmap) {
+                       struct id_map *map = iterator->elem;
+                       if (map->map_type != map_type)
+                               continue;
+
+                       had_entry = true;
+
+                       left = 4096 - (pos - mapbuf);
+                       fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
+                       /*
+                        * The kernel only takes <= 4k for writes to
+                        * /proc/<pid>/{g,u}id_map
+                        */
+                       if (fill <= 0 || fill >= left)
+                               return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
+
+                       pos += fill;
+               }
+               if (!had_entry)
+                       continue;
+
+               ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
+               if (ret < 0)
+                       return syserror("Failed to write mapping: %s", mapbuf);
+
+               memset(mapbuf, 0, sizeof(mapbuf));
+       }
+
+       return 0;
+}
+
+static int get_userns_fd_from_idmap(struct list *idmap)
+{
+       int ret;
+       pid_t pid;
+       char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
+                    STRLITERALLEN("/ns/user") + 1];
+
+       pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
+       if (pid < 0)
+               return -errno;
+
+       ret = map_ids_from_idmap(idmap, pid);
+       if (ret < 0)
+               return ret;
+
+       ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
+       if (ret < 0 || (size_t)ret >= sizeof(path_ns))
+               ret = -EIO;
+       else
+               ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
+
+       (void)kill(pid, SIGKILL);
+       (void)wait_for_pid(pid);
+       return ret;
+}
+
+int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
+{
+       struct list head, uid_mapl, gid_mapl;
+       struct id_map uid_map = {
+               .map_type       = ID_TYPE_UID,
+               .nsid           = nsid,
+               .hostid         = hostid,
+               .range          = range,
+       };
+       struct id_map gid_map = {
+               .map_type       = ID_TYPE_GID,
+               .nsid           = nsid,
+               .hostid         = hostid,
+               .range          = range,
+       };
+
+       list_init(&head);
+       uid_mapl.elem = &uid_map;
+       gid_mapl.elem = &gid_map;
+       list_add_tail(&head, &uid_mapl);
+       list_add_tail(&head, &gid_mapl);
+
+       return get_userns_fd_from_idmap(&head);
+}
+
+bool switch_ids(uid_t uid, gid_t gid)
+{
+       if (setgroups(0, NULL))
+               return syserror("failure: setgroups");
+
+       if (setresgid(gid, gid, gid))
+               return syserror("failure: setresgid");
+
+       if (setresuid(uid, uid, uid))
+               return syserror("failure: setresuid");
+
+       /* Ensure we can access proc files from processes we can ptrace. */
+       if (prctl(PR_SET_DUMPABLE, 1, 0, 0, 0))
+               return syserror("failure: make dumpable");
+
+       return true;
+}
+
+static int create_userns_hierarchy(struct userns_hierarchy *h);
+
+static int userns_fd_cb(void *data)
+{
+       struct userns_hierarchy *h = data;
+       char c;
+       int ret;
+
+       ret = read_nointr(h->fd_event, &c, 1);
+       if (ret < 0)
+               return syserror("failure: read from socketpair");
+
+       /* Only switch ids if someone actually wrote a mapping for us. */
+       if (c == '1') {
+               if (!switch_ids(0, 0))
+                       return syserror("failure: switch ids to 0");
+       }
+
+       ret = write_nointr(h->fd_event, "1", 1);
+       if (ret < 0)
+               return syserror("failure: write to socketpair");
+
+       ret = create_userns_hierarchy(++h);
+       if (ret < 0)
+               return syserror("failure: userns level %d", h->level);
+
+       return 0;
+}
+
+static int create_userns_hierarchy(struct userns_hierarchy *h)
+{
+       int fret = -1;
+       char c;
+       int fd_socket[2];
+       int fd_userns = -EBADF, ret = -1;
+       ssize_t bytes;
+       pid_t pid;
+       char path[256];
+
+       if (h->level == MAX_USERNS_LEVEL)
+               return 0;
+
+       ret = socketpair(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0, fd_socket);
+       if (ret < 0)
+               return syserror("failure: create socketpair");
+
+       /* Note the CLONE_FILES | CLONE_VM when mucking with fds and memory. */
+       h->fd_event = fd_socket[1];
+       pid = do_clone(userns_fd_cb, h, CLONE_NEWUSER | CLONE_FILES | CLONE_VM);
+       if (pid < 0) {
+               syserror("failure: userns level %d", h->level);
+               goto out_close;
+       }
+
+       ret = map_ids_from_idmap(&h->id_map, pid);
+       if (ret < 0) {
+               kill(pid, SIGKILL);
+               syserror("failure: writing id mapping for userns level %d for %d", h->level, pid);
+               goto out_wait;
+       }
+
+       if (!list_empty(&h->id_map))
+               bytes = write_nointr(fd_socket[0], "1", 1); /* Inform the child we wrote a mapping. */
+       else
+               bytes = write_nointr(fd_socket[0], "0", 1); /* Inform the child we didn't write a mapping. */
+       if (bytes < 0) {
+               kill(pid, SIGKILL);
+               syserror("failure: write to socketpair");
+               goto out_wait;
+       }
+
+       /* Wait for child to set*id() and become dumpable. */
+       bytes = read_nointr(fd_socket[0], &c, 1);
+       if (bytes < 0) {
+               kill(pid, SIGKILL);
+               syserror("failure: read from socketpair");
+               goto out_wait;
+       }
+
+       snprintf(path, sizeof(path), "/proc/%d/ns/user", pid);
+       fd_userns = open(path, O_RDONLY | O_CLOEXEC);
+       if (fd_userns < 0) {
+               kill(pid, SIGKILL);
+               syserror("failure: open userns level %d for %d", h->level, pid);
+               goto out_wait;
+       }
+
+       fret = 0;
+
+out_wait:
+       if (!wait_for_pid(pid) && !fret) {
+               h->fd_userns = fd_userns;
+               fd_userns = -EBADF;
+       }
+
+out_close:
+       if (fd_userns >= 0)
+               close(fd_userns);
+       close(fd_socket[0]);
+       close(fd_socket[1]);
+       return fret;
+}
+
+/* caps_down - lower all effective caps */
+int caps_down(void)
+{
+       bool fret = false;
+       cap_t caps = NULL;
+       int ret = -1;
+
+       caps = cap_get_proc();
+       if (!caps)
+               goto out;
+
+       ret = cap_clear_flag(caps, CAP_EFFECTIVE);
+       if (ret)
+               goto out;
+
+       ret = cap_set_proc(caps);
+       if (ret)
+               goto out;
+
+       fret = true;
+
+out:
+       cap_free(caps);
+       return fret;
+}
diff --git a/tools/testing/selftests/filesystems/utils.h b/tools/testing/selftests/filesystems/utils.h
new file mode 100644 (file)
index 0000000..f35001a
--- /dev/null
@@ -0,0 +1,44 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+
+#ifndef __IDMAP_UTILS_H
+#define __IDMAP_UTILS_H
+
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+#include <errno.h>
+#include <linux/types.h>
+#include <sched.h>
+#include <signal.h>
+#include <stdbool.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <syscall.h>
+#include <sys/capability.h>
+#include <sys/fsuid.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+extern int get_userns_fd(unsigned long nsid, unsigned long hostid,
+                        unsigned long range);
+
+extern int caps_down(void);
+
+extern bool switch_ids(uid_t uid, gid_t gid);
+
+static inline bool switch_userns(int fd, uid_t uid, gid_t gid, bool drop_caps)
+{
+       if (setns(fd, CLONE_NEWUSER))
+               return false;
+
+       if (!switch_ids(uid, gid))
+               return false;
+
+       if (drop_caps && !caps_down())
+               return false;
+
+       return true;
+}
+
+#endif /* __IDMAP_UTILS_H */