]> git.ipfire.org Git - thirdparty/wireguard-tools.git/commitdiff
wg: ipc: read from socket incrementally
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Jan 2017 03:50:42 +0000 (04:50 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Jan 2017 04:36:43 +0000 (05:36 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
src/Makefile
src/ipc.c

index fee7951a280574672b0620d8fc920a8c9170320c..6502c3d47b1c8ba67c555eadbf5a678c29decc08 100644 (file)
@@ -33,7 +33,7 @@ endif
 
 CFLAGS ?= -O3
 CFLAGS += -std=gnu11
-CFLAGS += -pedantic -Wall -Wextra
+CFLAGS += -Wall -Wextra
 CFLAGS += -MMD -MP
 CFLAGS += -DRUNSTATEDIR="\"$(RUNSTATEDIR)\""
 LDLIBS += -lresolv
index 623796115c73170ddad916250109460e07cc11c7..05609b4b5ac53c250014e66a10e0f243bfdd2578 100644 (file)
--- a/src/ipc.c
+++ b/src/ipc.c
@@ -18,7 +18,6 @@
 #include <unistd.h>
 #include <time.h>
 #include <dirent.h>
-#include <poll.h>
 #include <signal.h>
 #include <sys/socket.h>
 #include <sys/types.h>
@@ -41,7 +40,7 @@ struct inflatable_buffer {
        size_t pos;
 };
 
-#define max(a, b) (a > b ? a : b)
+#define max(a, b) ((a) > (b) ? (a) : (b))
 
 static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer)
 {
@@ -190,68 +189,75 @@ out:
        return (int)ret;
 }
 
+#define READ_BYTES(bytes) ({ \
+       void *__p; \
+       size_t __bytes = (bytes); \
+       if (bytes_left < __bytes) { \
+               offset = p - buffer; \
+               bytes_left += buffer_size; \
+               buffer_size *= 2; \
+               ret = -ENOMEM; \
+               p = realloc(buffer, buffer_size); \
+               if (!p) \
+                       goto out; \
+               buffer = p; \
+               p += offset; \
+       } \
+       bytes_left -= __bytes; \
+       ret = read(fd, p, __bytes); \
+       if (ret < 0) \
+               goto out; \
+       if ((size_t)ret != __bytes) { \
+               ret = -EBADMSG; \
+               goto out; \
+       } \
+       __p = p; \
+       p += __bytes; \
+       __p; \
+})
 static int userspace_get_device(struct wgdevice **dev, const char *interface)
 {
-       struct pollfd pollfd = { .events = POLLIN };
-       int len;
-       char byte = 0;
-       size_t i;
-       struct wgpeer *peer;
+       unsigned int len = 0, i;
+       size_t buffer_size, bytes_left;
        ssize_t ret;
+       ptrdiff_t offset;
+       uint8_t *buffer = NULL, *p, byte = 0;
+
        int fd = userspace_interface_fd(interface);
        if (fd < 0)
                return fd;
-       *dev = NULL;
+
        ret = write(fd, &byte, sizeof(byte));
        if (ret < 0)
                goto out;
-
-       pollfd.fd = fd;
-       if (poll(&pollfd, 1, -1) < 0)
-               goto out;
-       ret = -ECONNABORTED;
-       if (!(pollfd.revents & POLLIN))
-               goto out;
-
-       ret = ioctl(fd, FIONREAD, &len);
-       if (ret < 0) {
-               ret = -errno;
+       if (ret != sizeof(byte)) {
+               ret = -EBADMSG;
                goto out;
        }
-       ret = -EBADMSG;
-       if ((size_t)len < sizeof(struct wgdevice))
-               goto out;
 
+       ioctl(fd, FIONREAD, &len);
+       bytes_left = buffer_size = max(len, sizeof(struct wgdevice) + sizeof(struct wgpeer) + sizeof(struct wgipmask));
+       p = buffer = malloc(buffer_size);
        ret = -ENOMEM;
-       *dev = malloc(len);
-       if (!*dev)
+       if (!buffer)
                goto out;
 
-       ret = read(fd, *dev, len);
-       if (ret < 0)
-               goto out;
-       if (ret != len) {
-               ret = -EBADMSG;
-               goto out;
-       }
-
-       ret = -EBADMSG;
-       for_each_wgpeer(*dev, peer, i) {
-               if ((uint8_t *)peer + sizeof(struct wgpeer) > (uint8_t *)*dev + len)
-                       goto out;
-               if ((uint8_t *)peer + sizeof(struct wgpeer) + sizeof(struct wgipmask) * peer->num_ipmasks > (uint8_t *)*dev + len)
-               goto out;
-       }
+       len = ((struct wgdevice *)READ_BYTES(sizeof(struct wgdevice)))->num_peers;
+       for (i = 0; i < len; ++i)
+               READ_BYTES(sizeof(struct wgipmask) * ((struct wgpeer *)READ_BYTES(sizeof(struct wgpeer)))->num_ipmasks);
        ret = 0;
 out:
-       if (*dev && ret) {
-               free(*dev);
-               *dev = NULL;
+       if (buffer && ret) {
+               free(buffer);
+               buffer = NULL;
        }
+       *dev = (struct wgdevice *)buffer;
        close(fd);
        errno = -ret;
        return ret;
+
 }
+#undef READ_BYTES
 
 #ifdef __linux__
 static int parse_linkinfo(const struct nlattr *attr, void *data)