]> git.ipfire.org Git - thirdparty/openssh-portable.git/commitdiff
upstream commit
authordjm@openbsd.org <djm@openbsd.org>
Wed, 19 Jul 2017 01:15:02 +0000 (01:15 +0000)
committerDamien Miller <djm@mindrot.org>
Fri, 21 Jul 2017 04:17:33 +0000 (14:17 +1000)
switch from select() to poll() for the ssh-agent
mainloop; ok markus

Upstream-ID: 4a94888ee67b3fd948fd10693973beb12f802448

ssh-agent.c

index eb8c2043df2fba405e505017bc7a76fa7e6d779f..d858c24701a730c5d939133bca354144817d3b72 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssh-agent.c,v 1.222 2017/07/01 13:50:45 djm Exp $ */
+/* $OpenBSD: ssh-agent.c,v 1.223 2017/07/19 01:15:02 djm Exp $ */
 /*
  * Author: Tatu Ylonen <ylo@cs.hut.fi>
  * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -60,6 +60,9 @@
 #ifdef HAVE_PATHS_H
 # include <paths.h>
 #endif
+#ifdef HAVE_POLL_H
+# include <poll.h>
+#endif
 #include <signal.h>
 #include <stdarg.h>
 #include <stdio.h>
@@ -91,6 +94,9 @@
 # define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*"
 #endif
 
+/* Maximum accepted message length */
+#define AGENT_MAX_LEN  (256*1024)
+
 typedef enum {
        AUTH_UNUSED,
        AUTH_SOCKET,
@@ -634,30 +640,46 @@ send:
 
 /* dispatch incoming messages */
 
-static void
-process_message(SocketEntry *e)
+static int
+process_message(u_int socknum)
 {
        u_int msg_len;
        u_char type;
        const u_char *cp;
        int r;
+       SocketEntry *e;
+
+       if (socknum >= sockets_alloc) {
+               fatal("%s: socket number %u >= allocated %u",
+                   __func__, socknum, sockets_alloc);
+       }
+       e = &sockets[socknum];
 
        if (sshbuf_len(e->input) < 5)
-               return;         /* Incomplete message. */
+               return 0;               /* Incomplete message header. */
        cp = sshbuf_ptr(e->input);
        msg_len = PEEK_U32(cp);
-       if (msg_len > 256 * 1024) {
-               close_socket(e);
-               return;
+       if (msg_len > AGENT_MAX_LEN) {
+               debug("%s: socket %u (fd=%d) message too long %u > %u",
+                   __func__, socknum, e->fd, msg_len, AGENT_MAX_LEN);
+               return -1;
        }
        if (sshbuf_len(e->input) < msg_len + 4)
-               return;
+               return 0;               /* Incomplete message body. */
 
        /* move the current input to e->request */
        sshbuf_reset(e->request);
        if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 ||
-           (r = sshbuf_get_u8(e->request, &type)) != 0)
+           (r = sshbuf_get_u8(e->request, &type)) != 0) {
+               if (r == SSH_ERR_MESSAGE_INCOMPLETE ||
+                   r == SSH_ERR_STRING_TOO_LARGE) {
+                       debug("%s: buffer error: %s", __func__, ssh_err(r));
+                       return -1;
+               }
                fatal("%s: buffer error: %s", __func__, ssh_err(r));
+       }
+
+       debug("%s: socket %u (fd=%d) type %d", __func__, socknum, e->fd, type);
 
        /* check wheter agent is locked */
        if (locked && type != SSH_AGENTC_UNLOCK) {
@@ -671,10 +693,9 @@ process_message(SocketEntry *e)
                        /* send a fail message for all other request types */
                        send_status(e, 0);
                }
-               return;
+               return 0;
        }
 
-       debug("type %d", type);
        switch (type) {
        case SSH_AGENTC_LOCK:
        case SSH_AGENTC_UNLOCK:
@@ -716,6 +737,7 @@ process_message(SocketEntry *e)
                send_status(e, 0);
                break;
        }
+       return 0;
 }
 
 static void
@@ -757,19 +779,141 @@ new_socket(sock_type type, int fd)
 }
 
 static int
-prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
-    struct timeval **tvpp)
+handle_socket_read(u_int socknum)
+{
+       struct sockaddr_un sunaddr;
+       socklen_t slen;
+       uid_t euid;
+       gid_t egid;
+       int fd;
+
+       slen = sizeof(sunaddr);
+       fd = accept(sockets[socknum].fd, (struct sockaddr *)&sunaddr, &slen);
+       if (fd < 0) {
+               error("accept from AUTH_SOCKET: %s", strerror(errno));
+               return -1;
+       }
+       if (getpeereid(fd, &euid, &egid) < 0) {
+               error("getpeereid %d failed: %s", fd, strerror(errno));
+               close(fd);
+               return -1;
+       }
+       if ((euid != 0) && (getuid() != euid)) {
+               error("uid mismatch: peer euid %u != uid %u",
+                   (u_int) euid, (u_int) getuid());
+               close(fd);
+               return -1;
+       }
+       new_socket(AUTH_CONNECTION, fd);
+       return 0;
+}
+
+static int
+handle_conn_read(u_int socknum)
+{
+       char buf[1024];
+       ssize_t len;
+       int r;
+
+       if ((len = read(sockets[socknum].fd, buf, sizeof(buf))) <= 0) {
+               if (len == -1) {
+                       if (errno == EAGAIN || errno == EINTR)
+                               return 0;
+                       error("%s: read error on socket %u (fd %d): %s",
+                           __func__, socknum, sockets[socknum].fd,
+                           strerror(errno));
+               }
+               return -1;
+       }
+       if ((r = sshbuf_put(sockets[socknum].input, buf, len)) != 0)
+               fatal("%s: buffer error: %s", __func__, ssh_err(r));
+       explicit_bzero(buf, sizeof(buf));
+       process_message(socknum);
+       return 0;
+}
+
+static int
+handle_conn_write(u_int socknum)
+{
+       ssize_t len;
+       int r;
+
+       if (sshbuf_len(sockets[socknum].output) == 0)
+               return 0; /* shouldn't happen */
+       if ((len = write(sockets[socknum].fd,
+           sshbuf_ptr(sockets[socknum].output),
+           sshbuf_len(sockets[socknum].output))) <= 0) {
+               if (len == -1) {
+                       if (errno == EAGAIN || errno == EINTR)
+                               return 0;
+                       error("%s: read error on socket %u (fd %d): %s",
+                           __func__, socknum, sockets[socknum].fd,
+                           strerror(errno));
+               }
+               return -1;
+       }
+       if ((r = sshbuf_consume(sockets[socknum].output, len)) != 0)
+               fatal("%s: buffer error: %s", __func__, ssh_err(r));
+       return 0;
+}
+
+static void
+after_poll(struct pollfd *pfd, size_t npfd)
 {
-       u_int i, sz;
-       int n = 0;
-       static struct timeval tv;
+       size_t i;
+       u_int socknum;
+
+       for (i = 0; i < npfd; i++) {
+               if (pfd[i].revents == 0)
+                       continue;
+               /* Find sockets entry */
+               for (socknum = 0; socknum < sockets_alloc; socknum++) {
+                       if (sockets[socknum].type != AUTH_SOCKET &&
+                           sockets[socknum].type != AUTH_CONNECTION)
+                               continue;
+                       if (pfd[i].fd == sockets[socknum].fd)
+                               break;
+               }
+               if (socknum >= sockets_alloc) {
+                       error("%s: no socket for fd %d", __func__, pfd[i].fd);
+                       continue;
+               }
+               /* Process events */
+               switch (sockets[socknum].type) {
+               case AUTH_SOCKET:
+                       if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
+                           handle_socket_read(socknum) != 0)
+                               close_socket(&sockets[socknum]);
+                       break;
+               case AUTH_CONNECTION:
+                       if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
+                           handle_conn_read(socknum) != 0) {
+                               close_socket(&sockets[socknum]);
+                               break;
+                       }
+                       if ((pfd[i].revents & (POLLOUT|POLLHUP)) != 0 &&
+                           handle_conn_write(socknum) != 0)
+                               close_socket(&sockets[socknum]);
+                       break;
+               default:
+                       break;
+               }
+       }
+}
+
+static int
+prepare_poll(struct pollfd **pfdp, size_t *npfdp, int *timeoutp)
+{
+       struct pollfd *pfd = *pfdp;
+       size_t i, j, npfd = 0;
        time_t deadline;
 
+       /* Count active sockets */
        for (i = 0; i < sockets_alloc; i++) {
                switch (sockets[i].type) {
                case AUTH_SOCKET:
                case AUTH_CONNECTION:
-                       n = MAXIMUM(n, sockets[i].fd);
+                       npfd++;
                        break;
                case AUTH_UNUSED:
                        break;
@@ -778,28 +922,23 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
                        break;
                }
        }
+       if (npfd != *npfdp &&
+           (pfd = recallocarray(pfd, *npfdp, npfd, sizeof(*pfd))) == NULL)
+               fatal("%s: recallocarray failed", __func__);
+       *pfdp = pfd;
+       *npfdp = npfd;
 
-       sz = howmany(n+1, NFDBITS) * sizeof(fd_mask);
-       if (*fdrp == NULL || sz > *nallocp) {
-               free(*fdrp);
-               free(*fdwp);
-               *fdrp = xmalloc(sz);
-               *fdwp = xmalloc(sz);
-               *nallocp = sz;
-       }
-       if (n < *fdl)
-               debug("XXX shrink: %d < %d", n, *fdl);
-       *fdl = n;
-       memset(*fdrp, 0, sz);
-       memset(*fdwp, 0, sz);
-
-       for (i = 0; i < sockets_alloc; i++) {
+       for (i = j = 0; i < sockets_alloc; i++) {
                switch (sockets[i].type) {
                case AUTH_SOCKET:
                case AUTH_CONNECTION:
-                       FD_SET(sockets[i].fd, *fdrp);
+                       pfd[j].fd = sockets[i].fd;
+                       pfd[j].revents = 0;
+                       /* XXX backoff when input buffer full */
+                       pfd[j].events = POLLIN;
                        if (sshbuf_len(sockets[i].output) > 0)
-                               FD_SET(sockets[i].fd, *fdwp);
+                               pfd[j].events |= POLLOUT;
+                       j++;
                        break;
                default:
                        break;
@@ -810,98 +949,16 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
                deadline = (deadline == 0) ? parent_alive_interval :
                    MINIMUM(deadline, parent_alive_interval);
        if (deadline == 0) {
-               *tvpp = NULL;
+               *timeoutp = INFTIM;
        } else {
-               tv.tv_sec = deadline;
-               tv.tv_usec = 0;
-               *tvpp = &tv;
+               if (deadline > INT_MAX / 1000)
+                       *timeoutp = INT_MAX / 1000;
+               else
+                       *timeoutp = deadline * 1000;
        }
        return (1);
 }
 
-static void
-after_select(fd_set *readset, fd_set *writeset)
-{
-       struct sockaddr_un sunaddr;
-       socklen_t slen;
-       char buf[1024];
-       int len, sock, r;
-       u_int i, orig_alloc;
-       uid_t euid;
-       gid_t egid;
-
-       for (i = 0, orig_alloc = sockets_alloc; i < orig_alloc; i++)
-               switch (sockets[i].type) {
-               case AUTH_UNUSED:
-                       break;
-               case AUTH_SOCKET:
-                       if (FD_ISSET(sockets[i].fd, readset)) {
-                               slen = sizeof(sunaddr);
-                               sock = accept(sockets[i].fd,
-                                   (struct sockaddr *)&sunaddr, &slen);
-                               if (sock < 0) {
-                                       error("accept from AUTH_SOCKET: %s",
-                                           strerror(errno));
-                                       break;
-                               }
-                               if (getpeereid(sock, &euid, &egid) < 0) {
-                                       error("getpeereid %d failed: %s",
-                                           sock, strerror(errno));
-                                       close(sock);
-                                       break;
-                               }
-                               if ((euid != 0) && (getuid() != euid)) {
-                                       error("uid mismatch: "
-                                           "peer euid %u != uid %u",
-                                           (u_int) euid, (u_int) getuid());
-                                       close(sock);
-                                       break;
-                               }
-                               new_socket(AUTH_CONNECTION, sock);
-                       }
-                       break;
-               case AUTH_CONNECTION:
-                       if (sshbuf_len(sockets[i].output) > 0 &&
-                           FD_ISSET(sockets[i].fd, writeset)) {
-                               len = write(sockets[i].fd,
-                                   sshbuf_ptr(sockets[i].output),
-                                   sshbuf_len(sockets[i].output));
-                               if (len == -1 && (errno == EAGAIN ||
-                                   errno == EWOULDBLOCK ||
-                                   errno == EINTR))
-                                       continue;
-                               if (len <= 0) {
-                                       close_socket(&sockets[i]);
-                                       break;
-                               }
-                               if ((r = sshbuf_consume(sockets[i].output,
-                                   len)) != 0)
-                                       fatal("%s: buffer error: %s",
-                                           __func__, ssh_err(r));
-                       }
-                       if (FD_ISSET(sockets[i].fd, readset)) {
-                               len = read(sockets[i].fd, buf, sizeof(buf));
-                               if (len == -1 && (errno == EAGAIN ||
-                                   errno == EWOULDBLOCK ||
-                                   errno == EINTR))
-                                       continue;
-                               if (len <= 0) {
-                                       close_socket(&sockets[i]);
-                                       break;
-                               }
-                               if ((r = sshbuf_put(sockets[i].input,
-                                   buf, len)) != 0)
-                                       fatal("%s: buffer error: %s",
-                                           __func__, ssh_err(r));
-                               explicit_bzero(buf, sizeof(buf));
-                               process_message(&sockets[i]);
-                       }
-                       break;
-               default:
-                       fatal("Unknown type %d", sockets[i].type);
-               }
-}
-
 static void
 cleanup_socket(void)
 {
@@ -963,7 +1020,6 @@ main(int ac, char **av)
        int sock, fd, ch, result, saved_errno;
        u_int nalloc;
        char *shell, *format, *pidstr, *agentsocket = NULL;
-       fd_set *readsetp = NULL, *writesetp = NULL;
 #ifdef HAVE_SETRLIMIT
        struct rlimit rlim;
 #endif
@@ -971,9 +1027,11 @@ main(int ac, char **av)
        extern char *optarg;
        pid_t pid;
        char pidstrbuf[1 + 3 * sizeof pid];
-       struct timeval *tvp = NULL;
        size_t len;
        mode_t prev_mask;
+       int timeout = INFTIM;
+       struct pollfd *pfd = NULL;
+       size_t npfd = 0;
 
        ssh_malloc_init();      /* must be called before any mallocs */
        /* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */
@@ -1201,8 +1259,8 @@ skip:
        platform_pledge_agent();
 
        while (1) {
-               prepare_select(&readsetp, &writesetp, &max_fd, &nalloc, &tvp);
-               result = select(max_fd + 1, readsetp, writesetp, NULL, tvp);
+               prepare_poll(&pfd, &npfd, &timeout);
+               result = poll(pfd, npfd, timeout);
                saved_errno = errno;
                if (parent_alive_interval != 0)
                        check_parent_exists();
@@ -1210,9 +1268,9 @@ skip:
                if (result < 0) {
                        if (saved_errno == EINTR)
                                continue;
-                       fatal("select: %s", strerror(saved_errno));
+                       fatal("poll: %s", strerror(saved_errno));
                } else if (result > 0)
-                       after_select(readsetp, writesetp);
+                       after_poll(pfd, npfd);
        }
        /* NOTREACHED */
 }