]> git.ipfire.org Git - thirdparty/dhcpcd.git/commitdiff
Linux: keep the generic netlink socket around to get ssid with privsep
authorRoy Marples <roy@marples.name>
Mon, 22 Jun 2020 20:56:16 +0000 (21:56 +0100)
committerRoy Marples <roy@marples.name>
Mon, 22 Jun 2020 20:56:16 +0000 (21:56 +0100)
While here, improve our reading of netlink(7) and terminate on either
ERROR or DONE. If neither are in the message, read again unless it's
the link receiving socket.
Also, only callback if this is the sequence number expected.

src/if-linux.c

index b3ab82804331d1e77d12eb2cbcb81495928bfa39..815a06b713fc0f08a7362863e27350098a39a67c 100644 (file)
@@ -130,6 +130,7 @@ int if_getssid_wext(const char *ifname, uint8_t *ssid);
 
 struct priv {
        int route_fd;
+       int generic_fd;
        uint32_t route_pid;
 };
 
@@ -414,6 +415,12 @@ if_opensockets_os(struct dhcpcd_ctx *ctx)
        if (getsockname(priv->route_fd, (struct sockaddr *)&snl, &len) == -1)
                return -1;
        priv->route_pid = snl.nl_pid;
+
+       memset(&snl, 0, sizeof(snl));
+       priv->generic_fd = if_linksocket(&snl, NETLINK_GENERIC, 0);
+       if (priv->generic_fd == -1)
+               return -1;
+
        return 0;
 }
 
@@ -425,6 +432,7 @@ if_closesockets_os(struct dhcpcd_ctx *ctx)
        if (ctx->priv != NULL) {
                priv = (struct priv *)ctx->priv;
                close(priv->route_fd);
+               close(priv->generic_fd);
        }
 }
 
@@ -465,26 +473,27 @@ if_getnetlink(struct dhcpcd_ctx *ctx, struct iovec *iov, int fd, int flags,
        };
        ssize_t len;
        struct nlmsghdr *nlm;
-       int r;
+       int r = 0;
        unsigned int again;
+       bool terminated;
 
 recv_again:
-       if ((len = recvmsg(fd, &msg, flags)) == -1)
-               return -1;
-       if (len == 0)
-               return 0;
+       len = recvmsg(fd, &msg, flags);
+       if (len == -1 || len == 0)
+               return (int)len;
 
        /* Check sender */
        if (msg.msg_namelen != sizeof(nladdr)) {
                errno = EINVAL;
                return -1;
        }
+
        /* Ignore message if it is not from kernel */
        if (nladdr.nl_pid != 0)
                return 0;
 
-       r = 0;
        again = 0;
+       terminated = false;
        for (nlm = iov->iov_base;
             nlm && NLMSG_OK(nlm, (size_t)len);
             nlm = NLMSG_NEXT(nlm, len))
@@ -492,6 +501,7 @@ recv_again:
                again = (nlm->nlmsg_flags & NLM_F_MULTI);
                if (nlm->nlmsg_type == NLMSG_NOOP)
                        continue;
+
                if (nlm->nlmsg_type == NLMSG_ERROR) {
                        struct nlmsgerr *err;
 
@@ -504,17 +514,21 @@ recv_again:
                                errno = -err->error;
                                return -1;
                        }
+                       again = 0;
+                       terminated = true;
                        break;
                }
                if (nlm->nlmsg_type == NLMSG_DONE) {
                        again = 0;
+                       terminated = true;
                        break;
                }
-               if (cb != NULL && (r = cb(ctx, cbarg, nlm)) != 0)
-                       break;
+               if (cb != NULL &&
+                  (nlm->nlmsg_seq == (uint32_t)ctx->seq || fd == ctx->link_fd))
+                       r = cb(ctx, cbarg, nlm);
        }
 
-       if (r == 0 && again)
+       if ((again || !terminated) && (ctx != NULL && ctx->link_fd != fd))
                goto recv_again;
 
        return r;
@@ -982,16 +996,19 @@ static int
 if_sendnetlink(struct dhcpcd_ctx *ctx, int protocol, struct nlmsghdr *hdr,
     int (*cb)(struct dhcpcd_ctx *, void *, struct nlmsghdr *), void *cbarg)
 {
-       int s, r;
+       int s;
        struct sockaddr_nl snl = { .nl_family = AF_NETLINK };
        struct iovec iov = { .iov_base = hdr, .iov_len = hdr->nlmsg_len };
        struct msghdr msg = {
            .msg_name = &snl, .msg_namelen = sizeof(snl),
            .msg_iov = &iov, .msg_iovlen = 1
        };
-       bool use_rfd;
-
-       use_rfd = (protocol == NETLINK_ROUTE && hdr->nlmsg_type != RTM_GETADDR);
+       struct priv *priv = (struct priv *)ctx->priv;
+       unsigned char buf[16 * 1024];
+       struct iovec riov = {
+               .iov_base = buf,
+               .iov_len = sizeof(buf),
+       };
 
        /* Request a reply */
        hdr->nlmsg_flags |= NLM_F_ACK;
@@ -1002,13 +1019,16 @@ if_sendnetlink(struct dhcpcd_ctx *ctx, int protocol, struct nlmsghdr *hdr,
                return (int)ps_root_sendnetlink(ctx, protocol, &msg);
 #endif
 
-       if (use_rfd) {
-               struct priv *priv = (struct priv *)ctx->priv;
-
-               s = priv->route_fd;
-       } else {
-               if ((s = if_linksocket(&snl, protocol, 0)) == -1)
-                       return -1;
+       switch (protocol) {
+       case NETLINK_ROUTE:
+               if (hdr->nlmsg_type != RTM_GETADDR) {
+                       s = priv->route_fd;
+                       break;
+               }
+               /* FALLTHROUGH */
+       case NETLINK_GENERIC:
+               s = priv->generic_fd;
+#if 0
 #ifdef NETLINK_GET_STRICT_CHK
                if (hdr->nlmsg_type == RTM_GETADDR) {
                        int on = 1;
@@ -1018,22 +1038,17 @@ if_sendnetlink(struct dhcpcd_ctx *ctx, int protocol, struct nlmsghdr *hdr,
                                logerr("%s: NETLINK_GET_STRICT_CHK", __func__);
                }
 #endif
+#endif
+               break;
+       default:
+               errno = EINVAL;
+               return -1;
        }
 
-       if (sendmsg(s, &msg, 0) != -1) {
-               unsigned char buf[16 * 1024];
-               struct iovec riov = {
-                       .iov_base = buf,
-                       .iov_len = sizeof(buf),
-               };
-
-               r = if_getnetlink(ctx, &riov, s, 0, cb, cbarg);
-       } else
-               r = -1;
+       if (sendmsg(s, &msg, 0) == -1)
+               return -1;
 
-       if (!use_rfd)
-               close(s);
-       return r;
+       return if_getnetlink(ctx, &riov, s, 0, cb, cbarg);
 }
 
 #define NLMSG_TAIL(nmsg)                                               \