]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
varlink: support varlink communication via distinct input/output fds
authorLennart Poettering <lennart@poettering.net>
Wed, 24 Apr 2024 16:54:07 +0000 (18:54 +0200)
committerLennart Poettering <lennart@poettering.net>
Thu, 27 Jun 2024 07:41:54 +0000 (09:41 +0200)
When invoking another process via a pair of pipes it makes sense to
allow reading from one fd, and writing from another. Teach our varlink
code to do so optionally.

(sd-bus supports something similar, fill the gap).

This is preparation for a later commit that uses this to talk to remote
SSH invocations via pipes.

src/shared/varlink.c

index d0a86157a284ca223b9191f29917f6d5c1faebd4..3dec6dbfa7c5612fc52c19e93ab3dcbfecbae400 100644 (file)
@@ -139,7 +139,8 @@ struct Varlink {
                           * at most. */
         unsigned n_pending;
 
-        int fd;
+        int input_fd;
+        int output_fd;
 
         char *input_buffer; /* valid data starts at input_buffer_index, ends at input_buffer_index+input_buffer_size */
         size_t input_buffer_index;
@@ -185,7 +186,8 @@ struct Varlink {
 
         bool write_disconnected:1;
         bool read_disconnected:1;
-        bool prefer_read_write:1;
+        bool prefer_read:1;
+        bool prefer_write:1;
         bool got_pollhup:1;
 
         bool allow_fd_passing_input:1;
@@ -203,7 +205,8 @@ struct Varlink {
         char *description;
 
         sd_event *event;
-        sd_event_source *io_event_source;
+        sd_event_source *input_event_source;
+        sd_event_source *output_event_source;
         sd_event_source *time_event_source;
         sd_event_source *quit_event_source;
         sd_event_source *defer_event_source;
@@ -357,7 +360,8 @@ static int varlink_new(Varlink **ret) {
 
         *v = (Varlink) {
                 .n_ref = 1,
-                .fd = -EBADF,
+                .input_fd = -EBADF,
+                .output_fd = -EBADF,
 
                 .state = _VARLINK_STATE_INVALID,
 
@@ -387,11 +391,11 @@ int varlink_connect_address(Varlink **ret, const char *address) {
         if (r < 0)
                 return log_debug_errno(r, "Failed to create varlink object: %m");
 
-        v->fd = socket(AF_UNIX, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0);
-        if (v->fd < 0)
+        v->input_fd = socket(AF_UNIX, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0);
+        if (v->input_fd < 0)
                 return log_debug_errno(errno, "Failed to create AF_UNIX socket: %m");
 
-        v->fd = fd_move_above_stdio(v->fd);
+        v->output_fd = v->input_fd = fd_move_above_stdio(v->input_fd);
         v->af = AF_UNIX;
 
         r = sockaddr_un_set_path(&sockaddr.un, address);
@@ -402,9 +406,9 @@ int varlink_connect_address(Varlink **ret, const char *address) {
                 /* This is a file system path, and too long to fit into sockaddr_un. Let's connect via O_PATH
                  * to this socket. */
 
-                r = connect_unix_path(v->fd, AT_FDCWD, address);
+                r = connect_unix_path(v->input_fd, AT_FDCWD, address);
         } else
-                r = RET_NERRNO(connect(v->fd, &sockaddr.sa, r));
+                r = RET_NERRNO(connect(v->input_fd, &sockaddr.sa, r));
 
         if (r < 0) {
                 if (!IN_SET(r, -EAGAIN, -EINPROGRESS))
@@ -507,7 +511,7 @@ int varlink_connect_exec(Varlink **ret, const char *_command, char **_argv) {
         if (r < 0)
                 return log_debug_errno(r, "Failed to create varlink object: %m");
 
-        v->fd = TAKE_FD(pair[0]);
+        v->output_fd = v->input_fd = TAKE_FD(pair[0]);
         v->af = AF_UNIX;
         v->exec_pid = TAKE_PID(pid);
         varlink_set_state(v, VARLINK_IDLE_CLIENT);
@@ -583,7 +587,7 @@ static int varlink_connect_ssh(Varlink **ret, const char *where) {
         if (r < 0)
                 return log_debug_errno(r, "Failed to create varlink object: %m");
 
-        v->fd = TAKE_FD(pair[0]);
+        v->output_fd = v->input_fd = TAKE_FD(pair[0]);
         v->af = AF_UNIX;
         v->exec_pid = TAKE_PID(pid);
         varlink_set_state(v, VARLINK_IDLE_CLIENT);
@@ -663,7 +667,7 @@ int varlink_connect_fd(Varlink **ret, int fd) {
         if (r < 0)
                 return log_debug_errno(r, "Failed to create varlink object: %m");
 
-        v->fd = fd;
+        v->output_fd = v->input_fd = fd;
         v->af = -1,
         varlink_set_state(v, VARLINK_IDLE_CLIENT);
 
@@ -681,7 +685,8 @@ int varlink_connect_fd(Varlink **ret, int fd) {
 static void varlink_detach_event_sources(Varlink *v) {
         assert(v);
 
-        v->io_event_source = sd_event_source_disable_unref(v->io_event_source);
+        v->input_event_source = sd_event_source_disable_unref(v->input_event_source);
+        v->output_event_source = sd_event_source_disable_unref(v->output_event_source);
         v->time_event_source = sd_event_source_disable_unref(v->time_event_source);
         v->quit_event_source = sd_event_source_disable_unref(v->quit_event_source);
         v->defer_event_source = sd_event_source_disable_unref(v->defer_event_source);
@@ -706,7 +711,11 @@ static void varlink_clear(Varlink *v) {
 
         varlink_detach_event_sources(v);
 
-        v->fd = safe_close(v->fd);
+        if (v->input_fd != v->output_fd) {
+                v->input_fd = safe_close(v->input_fd);
+                v->output_fd = safe_close(v->output_fd);
+        } else
+                v->output_fd = v->input_fd = safe_close(v->input_fd);
 
         varlink_clear_current(v);
 
@@ -821,7 +830,7 @@ static int varlink_write(Varlink *v) {
         if (v->output_buffer_size == 0)
                 return 0;
 
-        assert(v->fd >= 0);
+        assert(v->output_fd >= 0);
 
         if (v->n_output_fds > 0) { /* If we shall send fds along, we must use sendmsg() */
                 struct iovec iov = {
@@ -842,20 +851,20 @@ static int varlink_write(Varlink *v) {
                 control->cmsg_type = SCM_RIGHTS;
                 memcpy(CMSG_DATA(control), v->output_fds, sizeof(int) * v->n_output_fds);
 
-                n = sendmsg(v->fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
+                n = sendmsg(v->output_fd, &mh, MSG_DONTWAIT|MSG_NOSIGNAL);
         } else {
                 /* We generally prefer recv()/send() (mostly because of MSG_NOSIGNAL) but also want to be compatible
                  * with non-socket IO, hence fall back automatically.
                  *
                  * Use a local variable to help gcc figure out that we set 'n' in all cases. */
-                bool prefer_write = v->prefer_read_write;
+                bool prefer_write = v->prefer_write;
                 if (!prefer_write) {
-                        n = send(v->fd, v->output_buffer + v->output_buffer_index, v->output_buffer_size, MSG_DONTWAIT|MSG_NOSIGNAL);
+                        n = send(v->output_fd, v->output_buffer + v->output_buffer_index, v->output_buffer_size, MSG_DONTWAIT|MSG_NOSIGNAL);
                         if (n < 0 && errno == ENOTSOCK)
-                                prefer_write = v->prefer_read_write = true;
+                                prefer_write = v->prefer_write = true;
                 }
                 if (prefer_write)
-                        n = write(v->fd, v->output_buffer + v->output_buffer_index, v->output_buffer_size);
+                        n = write(v->output_fd, v->output_buffer + v->output_buffer_index, v->output_buffer_size);
         }
         if (n < 0) {
                 if (errno == EAGAIN)
@@ -914,7 +923,7 @@ static int varlink_read(Varlink *v) {
         if (v->input_buffer_size >= VARLINK_BUFFER_MAX)
                 return -ENOBUFS;
 
-        assert(v->fd >= 0);
+        assert(v->input_fd >= 0);
 
         if (MALLOC_SIZEOF_SAFE(v->input_buffer) <= v->input_buffer_index + v->input_buffer_size) {
                 size_t add;
@@ -961,16 +970,16 @@ static int varlink_read(Varlink *v) {
                         .msg_controllen = v->input_control_buffer_size,
                 };
 
-                n = recvmsg_safe(v->fd, &mh, MSG_DONTWAIT|MSG_CMSG_CLOEXEC);
+                n = recvmsg_safe(v->input_fd, &mh, MSG_DONTWAIT|MSG_CMSG_CLOEXEC);
         } else {
-                bool prefer_read = v->prefer_read_write;
+                bool prefer_read = v->prefer_read;
                 if (!prefer_read) {
-                        n = recv(v->fd, p, rs, MSG_DONTWAIT);
+                        n = recv(v->input_fd, p, rs, MSG_DONTWAIT);
                         if (n < 0 && errno == ENOTSOCK)
-                                prefer_read = v->prefer_read_write = true;
+                                prefer_read = v->prefer_read = true;
                 }
                 if (prefer_read)
-                        n = read(v->fd, p, rs);
+                        n = read(v->input_fd, p, rs);
         }
         if (n < 0) {
                 if (errno == EAGAIN)
@@ -1666,7 +1675,7 @@ static void handle_revents(Varlink *v, int revents) {
 }
 
 int varlink_wait(Varlink *v, usec_t timeout) {
-        int r, fd, events;
+        int r, events;
         usec_t t;
 
         assert_return(v, -EINVAL);
@@ -1691,22 +1700,42 @@ int varlink_wait(Varlink *v, usec_t timeout) {
             (t == USEC_INFINITY || timeout < t))
                 t = timeout;
 
-        fd = varlink_get_fd(v);
-        if (fd < 0)
-                return fd;
-
         events = varlink_get_events(v);
         if (events < 0)
                 return events;
 
-        r = fd_wait_for_event(fd, events, t);
+        struct pollfd pollfd[2];
+        size_t n_poll_fd = 0;
+
+        if (v->input_fd == v->output_fd) {
+                pollfd[n_poll_fd++] = (struct pollfd) {
+                        .fd = v->input_fd,
+                        .events = events,
+                };
+        } else {
+                pollfd[n_poll_fd++] = (struct pollfd) {
+                        .fd = v->input_fd,
+                        .events = events & POLLIN,
+                };
+                pollfd[n_poll_fd++] = (struct pollfd) {
+                        .fd = v->output_fd,
+                        .events = events & POLLOUT,
+                };
+        };
+
+        r = ppoll_usec(pollfd, n_poll_fd, t);
         if (ERRNO_IS_NEG_TRANSIENT(r)) /* Treat EINTR as not a timeout, but also nothing happened, and
                                         * the caller gets a chance to call back into us */
                 return 1;
         if (r <= 0)
                 return r;
 
-        handle_revents(v, r);
+        /* Merge the seen events into one */
+        int revents = 0;
+        FOREACH_ARRAY(p, pollfd, n_poll_fd)
+                revents |= p->revents;
+
+        handle_revents(v, revents);
         return 1;
 }
 
@@ -1725,10 +1754,12 @@ int varlink_get_fd(Varlink *v) {
 
         if (v->state == VARLINK_DISCONNECTED)
                 return varlink_log_errno(v, SYNTHETIC_ERRNO(ENOTCONN), "Not connected.");
-        if (v->fd < 0)
+        if (v->input_fd != v->output_fd)
+                return varlink_log_errno(v, SYNTHETIC_ERRNO(EBADF), "Separate file descriptors for input/output set.");
+        if (v->input_fd < 0)
                 return varlink_log_errno(v, SYNTHETIC_ERRNO(EBADF), "No valid fd.");
 
-        return v->fd;
+        return v->input_fd;
 }
 
 int varlink_get_events(Varlink *v) {
@@ -1797,7 +1828,7 @@ int varlink_flush(Varlink *v) {
                         continue;
                 }
 
-                r = fd_wait_for_event(v->fd, POLLOUT, USEC_INFINITY);
+                r = fd_wait_for_event(v->output_fd, POLLOUT, USEC_INFINITY);
                 if (ERRNO_IS_NEG_TRANSIENT(r))
                         continue;
                 if (r < 0)
@@ -2794,7 +2825,12 @@ static int varlink_acquire_ucred(Varlink *v) {
         if (v->ucred_acquired)
                 return 0;
 
-        r = getpeercred(v->fd, &v->ucred);
+        /* If we are connected asymmetrically, let's refuse, since it's not clear if caller wants to know
+         * peer on read or write fd */
+        if (v->input_fd != v->output_fd)
+                return -EBADF;
+
+        r = getpeercred(v->input_fd, &v->ucred);
         if (r < 0)
                 return r;
 
@@ -2859,7 +2895,10 @@ static int varlink_acquire_pidfd(Varlink *v) {
         if (v->peer_pidfd >= 0)
                 return 0;
 
-        v->peer_pidfd = getpeerpidfd(v->fd);
+        if (v->input_fd != v->output_fd)
+                return -EBADF;
+
+        v->peer_pidfd = getpeerpidfd(v->input_fd);
         if (v->peer_pidfd < 0)
                 return v->peer_pidfd;
 
@@ -2962,7 +3001,14 @@ static int prepare_callback(sd_event_source *s, void *userdata) {
         if (e < 0)
                 return e;
 
-        r = sd_event_source_set_io_events(v->io_event_source, e);
+        if (v->input_event_source == v->output_event_source)
+                /* Same fd for input + output */
+                r = sd_event_source_set_io_events(v->input_event_source, e);
+        else {
+                r = sd_event_source_set_io_events(v->input_event_source, e & EPOLLIN);
+                if (r >= 0)
+                        r = sd_event_source_set_io_events(v->output_event_source, e & EPOLLOUT);
+        }
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to set source events: %m");
 
@@ -3029,19 +3075,33 @@ int varlink_attach_event(Varlink *v, sd_event *e, int64_t priority) {
 
         (void) sd_event_source_set_description(v->quit_event_source, "varlink-quit");
 
-        r = sd_event_add_io(v->event, &v->io_event_source, v->fd, 0, io_callback, v);
+        r = sd_event_add_io(v->event, &v->input_event_source, v->input_fd, 0, io_callback, v);
         if (r < 0)
                 goto fail;
 
-        r = sd_event_source_set_prepare(v->io_event_source, prepare_callback);
+        r = sd_event_source_set_prepare(v->input_event_source, prepare_callback);
         if (r < 0)
                 goto fail;
 
-        r = sd_event_source_set_priority(v->io_event_source, priority);
+        r = sd_event_source_set_priority(v->input_event_source, priority);
         if (r < 0)
                 goto fail;
 
-        (void) sd_event_source_set_description(v->io_event_source, "varlink-io");
+        (void) sd_event_source_set_description(v->input_event_source, "varlink-input");
+
+        if (v->input_fd == v->output_fd)
+                v->output_event_source = sd_event_source_ref(v->input_event_source);
+        else {
+                r = sd_event_add_io(v->event, &v->output_event_source, v->output_fd, 0, io_callback, v);
+                if (r < 0)
+                        goto fail;
+
+                r = sd_event_source_set_priority(v->output_event_source, priority);
+                if (r < 0)
+                        goto fail;
+
+                (void) sd_event_source_set_description(v->output_event_source, "varlink-output");
+        }
 
         r = sd_event_add_defer(v->event, &v->defer_event_source, defer_callback, v);
         if (r < 0)
@@ -3187,16 +3247,23 @@ static int verify_unix_socket(Varlink *v) {
          *    • otherwise: v->af contains the address family we determined */
 
         if (v->af < 0) {
+                /* If we have distinct input + output fds, we don't consider ourselves to be connected via a regular
+                 * AF_UNIX socket. */
+                if (v->input_fd != v->output_fd) {
+                        v->af = AF_UNSPEC;
+                        return -ENOTSOCK;
+                }
+
                 struct stat st;
 
-                if (fstat(v->fd, &st) < 0)
+                if (fstat(v->input_fd, &st) < 0)
                         return -errno;
                 if (!S_ISSOCK(st.st_mode)) {
                         v->af = AF_UNSPEC;
                         return -ENOTSOCK;
                 }
 
-                v->af = socket_get_family(v->fd);
+                v->af = socket_get_family(v->input_fd);
                 if (v->af < 0)
                         return v->af;
         }
@@ -3411,7 +3478,7 @@ int varlink_server_add_connection(VarlinkServer *server, int fd, Varlink **ret)
         if (r < 0)
                 return r;
 
-        v->fd = fd;
+        v->input_fd = v->output_fd = fd;
         if (server->flags & VARLINK_SERVER_INHERIT_USERDATA)
                 v->userdata = server->userdata;
 
@@ -3421,7 +3488,7 @@ int varlink_server_add_connection(VarlinkServer *server, int fd, Varlink **ret)
         }
 
         _cleanup_free_ char *desc = NULL;
-        if (asprintf(&desc, "%s-%i", varlink_server_description(server), v->fd) >= 0)
+        if (asprintf(&desc, "%s-%i", varlink_server_description(server), fd) >= 0)
                 v->description = TAKE_PTR(desc);
 
         /* Link up the server and the connection, and take reference in both directions. Note that the
@@ -3436,7 +3503,8 @@ int varlink_server_add_connection(VarlinkServer *server, int fd, Varlink **ret)
                 r = varlink_attach_event(v, server->event, server->event_priority);
                 if (r < 0) {
                         varlink_log_errno(v, r, "Failed to attach new connection: %m");
-                        v->fd = -EBADF; /* take the fd out of the connection again */
+                        TAKE_FD(v->input_fd); /* take the fd out of the connection again */
+                        TAKE_FD(v->output_fd);
                         varlink_close(v);
                         return r;
                 }