]> git.ipfire.org Git - thirdparty/dhcpcd.git/commitdiff
privsep: Ensure we recv for real after a successful recv MSG_PEEK
authorRoy Marples <roy@marples.name>
Tue, 17 Feb 2026 12:48:55 +0000 (12:48 +0000)
committerGitHub <noreply@github.com>
Tue, 17 Feb 2026 12:48:55 +0000 (12:48 +0000)
* privsep: Ensure we recv for real after a successful recv MSG_PEEK

Adjust the code flow so that the same errors would be caught
after the final recv.
This ensures we read what is really meant for us and not
something silly.

Return EBADMSG on recvmsg len mismatch.

src/privsep-root.c
src/privsep.c
src/privsep.h

index 28dfcc4d45af821ea6bf4ad6c914bab37162bd15..f4e65b0e068fe71b115d0cee439f3084870418c8 100644 (file)
@@ -71,108 +71,119 @@ struct psr_ctx {
        struct psr_error psr_error;
        size_t psr_datalen;
        void *psr_data;
-       size_t psr_mdatalen;
-       void *psr_mdata;
-       bool psr_usemdata;
+       bool psr_mallocdata;
 };
 
 static ssize_t
-ps_root_readerrorcb(struct psr_ctx *psr_ctx)
+ps_root_readerrorcb(struct psr_ctx *pc)
 {
-       struct dhcpcd_ctx *ctx = psr_ctx->psr_ctx;
+       struct dhcpcd_ctx *ctx = pc->psr_ctx;
        int fd = PS_ROOT_FD(ctx);
-       struct psr_error *psr_error = &psr_ctx->psr_error;
+       struct psr_error *psr_error = &pc->psr_error;
        struct iovec iov[] = {
                { .iov_base = psr_error, .iov_len = sizeof(*psr_error) },
-               { .iov_base = NULL, .iov_len = 0 },
+               { .iov_base = pc->psr_data, .iov_len = pc->psr_datalen },
        };
+       struct msghdr msg = { .msg_iov = iov, .msg_iovlen = __arraycount(iov) };
        ssize_t len;
 
 #define PSR_ERROR(e)                           \
        do {                                    \
-               psr_error->psr_result = -1;     \
                psr_error->psr_errno = (e);     \
-               return -1;                      \
+               goto error;                     \
        } while (0 /* CONSTCOND */)
 
        if (eloop_waitfd(fd) == -1)
                PSR_ERROR(errno);
 
-       len = recv(fd, psr_error, sizeof(*psr_error), MSG_PEEK);
+       if (!pc->psr_mallocdata)
+               goto recv;
+
+       /* We peek at the psr_error structure to tell us how much of a buffer
+        * we need to read the whole packet. */
+       msg.msg_iovlen--;
+       len = recvmsg(fd, &msg, MSG_PEEK | MSG_WAITALL);
        if (len == -1)
                PSR_ERROR(errno);
-       else if ((size_t)len < sizeof(*psr_error))
-               PSR_ERROR(EINVAL);
 
-       if (psr_error->psr_datalen > SSIZE_MAX)
-               PSR_ERROR(ENOBUFS);
-       if (psr_ctx->psr_usemdata &&
-           psr_error->psr_datalen > psr_ctx->psr_mdatalen)
-       {
-               void *d = realloc(psr_ctx->psr_mdata, psr_error->psr_datalen);
-               if (d == NULL)
-                       PSR_ERROR(errno);
-               psr_ctx->psr_mdata = d;
-               psr_ctx->psr_mdatalen = psr_error->psr_datalen;
+       /* After this point, we MUST do another recvmsg even on a failure
+        * to remove the message after peeking. */
+       if ((size_t)len < sizeof(*psr_error)) {
+               /* We can't use the header to work out buffers, so
+                * remove the message and bail. */
+               (void)recvmsg(fd, &msg, MSG_WAITALL);
+               PSR_ERROR(EINVAL);
        }
-       if (psr_error->psr_datalen != 0) {
-               if (psr_ctx->psr_usemdata)
-                       iov[1].iov_base = psr_ctx->psr_mdata;
-               else {
-                       if (psr_error->psr_datalen > psr_ctx->psr_datalen)
-                               PSR_ERROR(ENOBUFS);
-                       iov[1].iov_base = psr_ctx->psr_data;
-               }
+
+       /* No data to read? Unlikely but ... */
+       if (psr_error->psr_datalen == 0)
+               goto recv;
+
+       pc->psr_data = malloc(psr_error->psr_datalen);
+       if (pc->psr_data != NULL) {
+               iov[1].iov_base = pc->psr_data;
                iov[1].iov_len = psr_error->psr_datalen;
+               msg.msg_iovlen++;
        }
 
-       len = readv(fd, iov, __arraycount(iov));
+recv:
+       len = recvmsg(fd, &msg, MSG_WAITALL);
        if (len == -1)
                PSR_ERROR(errno);
-       else if ((size_t)len != sizeof(*psr_error) + psr_error->psr_datalen)
+       else if ((size_t)len < sizeof(*psr_error))
                PSR_ERROR(EINVAL);
+       else if (msg.msg_flags & MSG_TRUNC)
+               PSR_ERROR(ENOBUFS);
+       else if ((size_t)len != sizeof(*psr_error) + psr_error->psr_datalen) {
+#ifdef PRIVSEP_DEBUG
+               logerrx("%s: recvmsg returned %zd, expecting %zu", __func__,
+                   len, sizeof(*psr_error) + psr_error->psr_datalen);
+#endif
+               PSR_ERROR(EBADMSG);
+       }
        return len;
+
+error:
+       psr_error->psr_result = -1;
+       if (pc->psr_mallocdata && pc->psr_data != NULL) {
+               free(pc->psr_data);
+               pc->psr_data = NULL;
+       }
+       return -1;
 }
 
 ssize_t
 ps_root_readerror(struct dhcpcd_ctx *ctx, void *data, size_t len)
 {
-       struct psr_ctx *pc = ctx->ps_root->psp_data;
+       struct psr_ctx pc = {
+               .psr_ctx = ctx,
+               .psr_data = data,
+               .psr_datalen = len,
+               .psr_mallocdata = false
+       };
 
-       pc->psr_data = data;
-       pc->psr_datalen = len;
-       pc->psr_usemdata = false;
-       ps_root_readerrorcb(pc);
+       ps_root_readerrorcb(&pc);
 
-       errno = pc->psr_error.psr_errno;
-       return pc->psr_error.psr_result;
+       errno = pc.psr_error.psr_errno;
+       return pc.psr_error.psr_result;
 }
 
 ssize_t
 ps_root_mreaderror(struct dhcpcd_ctx *ctx, void **data, size_t *len)
 {
-       struct psr_ctx *pc = ctx->ps_root->psp_data;
-       void *d;
+       struct psr_ctx pc = {
+               .psr_ctx = ctx,
+               .psr_data = NULL,
+               .psr_datalen = 0,
+               .psr_mallocdata = true
+       };
 
-       pc->psr_usemdata = true;
-       ps_root_readerrorcb(pc);
+       ps_root_readerrorcb(&pc);
 
-       if (pc->psr_error.psr_datalen != 0) {
-               if (pc->psr_error.psr_datalen > pc->psr_mdatalen) {
-                       errno = EINVAL;
-                       return -1;
-               }
-               d = malloc(pc->psr_error.psr_datalen);
-               if (d == NULL)
-                       return -1;
-               memcpy(d, pc->psr_mdata, pc->psr_error.psr_datalen);
-       } else
-               d = NULL;
-
-       errno = pc->psr_error.psr_errno;
-       *data = d;
-       *len = pc->psr_error.psr_datalen;
-       return pc->psr_error.psr_result;
+       errno = pc.psr_error.psr_errno;
+       *data = pc.psr_data;
+       *len = pc.psr_error.psr_datalen;
+       return pc.psr_error.psr_result;
 }
 
 static ssize_t
@@ -196,6 +207,8 @@ ps_root_writeerror(struct dhcpcd_ctx *ctx, ssize_t result,
        logdebugx("%s: result %zd errno %d", __func__, result, errno);
 #endif
 
+       if (len == 0)
+               msg.msg_iovlen = 1;
        err = sendmsg(fd, &msg, MSG_EOR);
 
        /* Error sending the message? Try sending the error of sending. */
@@ -204,8 +217,8 @@ ps_root_writeerror(struct dhcpcd_ctx *ctx, ssize_t result,
                    __func__, result, data, len);
                psr.psr_result = err;
                psr.psr_errno = errno;
-               iov[1].iov_base = NULL;
-               iov[1].iov_len = 0;
+               psr.psr_datalen = 0;
+               msg.msg_iovlen = 1;
                err = sendmsg(fd, &msg, MSG_EOR);
        }
 
@@ -602,7 +615,7 @@ ps_root_recvmsgcb(void *arg, struct ps_msghdr *psm, struct msghdr *msg)
                break;
        }
 
-       err = ps_root_writeerror(ctx, err, rlen != 0 ? rdata : 0, rlen);
+       err = ps_root_writeerror(ctx, err, rdata, rlen);
        if (free_rdata)
                free(rdata);
        return err;
@@ -843,17 +856,6 @@ ps_root_log(void *arg, unsigned short events)
                logerr(__func__);
 }
 
-static void
-ps_root_freepsdata(void *arg)
-{
-       struct psr_ctx *pc = arg;
-
-       if (pc == NULL)
-               return;
-       free(pc->psr_mdata);
-       free(pc);
-}
-
 pid_t
 ps_root_start(struct dhcpcd_ctx *ctx)
 {
@@ -864,7 +866,6 @@ ps_root_start(struct dhcpcd_ctx *ctx)
        struct ps_process *psp;
        int logfd[2] = { -1, -1}, datafd[2] = { -1, -1};
        pid_t pid;
-       struct psr_ctx *pc;
 
        if (xsocketpair(AF_UNIX, SOCK_SEQPACKET | SOCK_CXNB, 0, logfd) == -1)
                return -1;
@@ -883,27 +884,15 @@ ps_root_start(struct dhcpcd_ctx *ctx)
                return -1;
 #endif
 
-       pc = calloc(1, sizeof(*pc));
-       if (pc == NULL)
-               return -1;
-       pc->psr_ctx = ctx;
-
        psp = ctx->ps_root = ps_newprocess(ctx, &id);
        if (psp == NULL)
-       {
-               free(pc);
                return -1;
-       }
-       psp->psp_freedata = ps_root_freepsdata;
+
        strlcpy(psp->psp_name, "privileged proxy", sizeof(psp->psp_name));
        pid = ps_startprocess(psp, ps_root_recvmsg, NULL,
            ps_root_startcb, PSF_ELOOP);
-       if (pid == -1) {
-               free(pc);
+       if (pid == -1)
                return -1;
-       }
-
-       psp->psp_data = pc;
 
        if (pid == 0) {
                ctx->ps_log_fd = logfd[0]; /* Keep open to pass to processes */
index 7a81c196671359fd9996b53b6eb197349320be53..1cb2dd77aaa4e457915b4f688d1f604ecee17fb8 100644 (file)
@@ -761,11 +761,6 @@ ps_freeprocess(struct ps_process *psp)
 
        TAILQ_REMOVE(&ctx->ps_processes, psp, next);
 
-       if (psp->psp_freedata != NULL)
-               psp->psp_freedata(psp->psp_data);
-       else
-               free(psp->psp_data);
-
        if (psp->psp_fd != -1) {
                eloop_event_delete(ctx->eloop, psp->psp_fd);
                close(psp->psp_fd);
index 37380d4c8b8251be5650ab6cddffeee877b70e35..496c9cd555d4b60e097b91d5ace9c9ecd37c8733 100644 (file)
@@ -184,8 +184,6 @@ struct ps_process {
        char psp_name[PSP_NAMESIZE];
        uint16_t psp_proto;
        const char *psp_protostr;
-       void *psp_data;
-       void (*psp_freedata)(void *);
        bool psp_started;
 
 #ifdef INET