]> git.ipfire.org Git - thirdparty/wireguard-tools.git/commitdiff
wg: store tail pointer to make coalescing peers fast
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Oct 2017 15:17:43 +0000 (17:17 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Oct 2017 15:19:01 +0000 (17:19 +0200)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
src/containers.h
src/ipc.c

index 4c80a770bdda809450649f53d4d5b9a409dc6e46..2d0195d8aeab9feb6197b469c3bf86468e91af19 100644 (file)
@@ -46,7 +46,7 @@ struct wgpeer {
        uint64_t rx_bytes, tx_bytes;
        uint16_t persistent_keepalive_interval;
 
-       struct wgallowedip *first_allowedip;
+       struct wgallowedip *first_allowedip, *last_allowedip;
        struct wgpeer *next_peer;
 };
 
@@ -73,7 +73,7 @@ struct wgdevice {
        uint32_t fwmark;
        uint16_t listen_port;
 
-       struct wgpeer *first_peer;
+       struct wgpeer *first_peer, *last_peer;
 };
 
 #define for_each_wgpeer(__dev, __peer) for ((__peer) = (__dev)->first_peer; (__peer); (__peer) = (__peer)->next_peer)
index 521a904320a2a1797536eb3b8322f66be3d36f2b..db1936211fbbecd1fc84fa8898efdee652d85103 100644 (file)
--- a/src/ipc.c
+++ b/src/ipc.c
@@ -678,30 +678,24 @@ out:
        return ret;
 }
 
-struct get_device_ctx {
-       struct wgdevice *device;
-       struct wgpeer *peer;
-       struct wgallowedip *allowedip;
-};
-
 static int parse_allowedip(const struct nlattr *attr, void *data)
 {
-       struct get_device_ctx *ctx = data;
+       struct wgallowedip *allowedip = data;
 
        switch (mnl_attr_get_type(attr)) {
        case WGALLOWEDIP_A_FAMILY:
                if (!mnl_attr_validate(attr, MNL_TYPE_U16))
-                       ctx->allowedip->family = mnl_attr_get_u16(attr);
+                       allowedip->family = mnl_attr_get_u16(attr);
                break;
        case WGALLOWEDIP_A_IPADDR:
-               if (mnl_attr_get_payload_len(attr) == sizeof(ctx->allowedip->ip4))
-                       memcpy(&ctx->allowedip->ip4, mnl_attr_get_payload(attr), sizeof(ctx->allowedip->ip4));
-               else if (mnl_attr_get_payload_len(attr) == sizeof(ctx->allowedip->ip6))
-                       memcpy(&ctx->allowedip->ip6, mnl_attr_get_payload(attr), sizeof(ctx->allowedip->ip6));
+               if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4))
+                       memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4));
+               else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6))
+                       memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6));
                break;
        case WGALLOWEDIP_A_CIDR_MASK:
                if (!mnl_attr_validate(attr, MNL_TYPE_U8))
-                       ctx->allowedip->cidr = mnl_attr_get_u8(attr);
+                       allowedip->cidr = mnl_attr_get_u8(attr);
                break;
        default:
                warn_unrecognized("netlink");
@@ -712,68 +706,70 @@ static int parse_allowedip(const struct nlattr *attr, void *data)
 
 static int parse_allowedips(const struct nlattr *attr, void *data)
 {
-       struct get_device_ctx *ctx = data;
+       struct wgpeer *peer = data;
        struct wgallowedip *new_allowedip = calloc(1, sizeof(struct wgallowedip));
        int ret;
+
        if (!new_allowedip) {
                perror("calloc");
                return MNL_CB_ERROR;
        }
-       if (ctx->allowedip)
-               ctx->allowedip->next_allowedip = new_allowedip;
-       else
-               ctx->peer->first_allowedip = new_allowedip;
-       ctx->allowedip = new_allowedip;
-       ret = mnl_attr_parse_nested(attr, parse_allowedip, ctx);
+       if (!peer->first_allowedip)
+               peer->first_allowedip = peer->last_allowedip = new_allowedip;
+       else {
+               peer->last_allowedip->next_allowedip = new_allowedip;
+               peer->last_allowedip = new_allowedip;
+       }
+       ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip);
        if (!ret)
                return ret;
-       if (!((ctx->allowedip->family == AF_INET && ctx->allowedip->cidr <= 32) || (ctx->allowedip->family == AF_INET6 && ctx->allowedip->cidr <= 128)))
+       if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128)))
                return MNL_CB_ERROR;
        return MNL_CB_OK;
 }
 
 static int parse_peer(const struct nlattr *attr, void *data)
 {
-       struct get_device_ctx *ctx = data;
+       struct wgpeer *peer = data;
 
        switch (mnl_attr_get_type(attr)) {
        case WGPEER_A_PUBLIC_KEY:
-               if (mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->public_key))
-                       memcpy(ctx->peer->public_key, mnl_attr_get_payload(attr), sizeof(ctx->peer->public_key));
+               if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key))
+                       memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key));
                break;
        case WGPEER_A_PRESHARED_KEY:
-               if (mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->preshared_key))
-                       memcpy(ctx->peer->preshared_key, mnl_attr_get_payload(attr), sizeof(ctx->peer->preshared_key));
+               if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key))
+                       memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key));
                break;
        case WGPEER_A_ENDPOINT: {
                struct sockaddr *addr;
                if (mnl_attr_get_payload_len(attr) < sizeof(*addr))
                        break;
                addr = mnl_attr_get_payload(attr);
-               if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->endpoint.addr4))
-                       memcpy(&ctx->peer->endpoint.addr4, addr, sizeof(ctx->peer->endpoint.addr4));
-               else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->endpoint.addr6))
-                       memcpy(&ctx->peer->endpoint.addr6, addr, sizeof(ctx->peer->endpoint.addr6));
+               if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4))
+                       memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4));
+               else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6))
+                       memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6));
                break;
        }
        case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL:
                if (!mnl_attr_validate(attr, MNL_TYPE_U16))
-                       ctx->peer->persistent_keepalive_interval = mnl_attr_get_u16(attr);
+                       peer->persistent_keepalive_interval = mnl_attr_get_u16(attr);
                break;
        case WGPEER_A_LAST_HANDSHAKE_TIME:
-               if (mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->last_handshake_time))
-                       memcpy(&ctx->peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(ctx->peer->last_handshake_time));
+               if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time))
+                       memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time));
                break;
        case WGPEER_A_RX_BYTES:
                if (!mnl_attr_validate(attr, MNL_TYPE_U64))
-                       ctx->peer->rx_bytes = mnl_attr_get_u64(attr);
+                       peer->rx_bytes = mnl_attr_get_u64(attr);
                break;
        case WGPEER_A_TX_BYTES:
                if (!mnl_attr_validate(attr, MNL_TYPE_U64))
-                       ctx->peer->tx_bytes = mnl_attr_get_u64(attr);
+                       peer->tx_bytes = mnl_attr_get_u64(attr);
                break;
        case WGPEER_A_ALLOWEDIPS:
-               return mnl_attr_parse_nested(attr, parse_allowedips, ctx);
+               return mnl_attr_parse_nested(attr, parse_allowedips, peer);
        default:
                warn_unrecognized("netlink");
        }
@@ -783,58 +779,59 @@ static int parse_peer(const struct nlattr *attr, void *data)
 
 static int parse_peers(const struct nlattr *attr, void *data)
 {
-       struct get_device_ctx *ctx = data;
+       struct wgdevice *device = data;
        struct wgpeer *new_peer = calloc(1, sizeof(struct wgpeer));
        int ret;
+
        if (!new_peer) {
                perror("calloc");
                return MNL_CB_ERROR;
        }
-       if (ctx->peer)
-               ctx->peer->next_peer = new_peer;
-       else
-               ctx->device->first_peer = new_peer;
-       ctx->peer = new_peer;
-       ctx->allowedip = NULL;
-       ret = mnl_attr_parse_nested(attr, parse_peer, ctx);
+       if (!device->first_peer)
+               device->first_peer = device->last_peer = new_peer;
+       else {
+               device->last_peer->next_peer = new_peer;
+               device->last_peer = new_peer;
+       }
+       ret = mnl_attr_parse_nested(attr, parse_peer, new_peer);
        if (!ret)
                return ret;
-       if (key_is_zero(ctx->peer->public_key))
+       if (key_is_zero(new_peer->public_key))
                return MNL_CB_ERROR;
        return MNL_CB_OK;
 }
 
 static int parse_device(const struct nlattr *attr, void *data)
 {
-       struct get_device_ctx *ctx = data;
+       struct wgdevice *device = data;
 
        switch (mnl_attr_get_type(attr)) {
        case WGDEVICE_A_IFINDEX:
                if (!mnl_attr_validate(attr, MNL_TYPE_U32))
-                       ctx->device->ifindex = mnl_attr_get_u32(attr);
+                       device->ifindex = mnl_attr_get_u32(attr);
                break;
        case WGDEVICE_A_IFNAME:
                if (!mnl_attr_validate(attr, MNL_TYPE_STRING))
-                       strncpy(ctx->device->name, mnl_attr_get_str(attr), sizeof(ctx->device->name) - 1);
+                       strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1);
                break;
        case WGDEVICE_A_PRIVATE_KEY:
-               if (mnl_attr_get_payload_len(attr) == sizeof(ctx->device->private_key))
-                       memcpy(ctx->device->private_key, mnl_attr_get_payload(attr), sizeof(ctx->device->private_key));
+               if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key))
+                       memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key));
                break;
        case WGDEVICE_A_PUBLIC_KEY:
-               if (mnl_attr_get_payload_len(attr) == sizeof(ctx->device->public_key))
-                       memcpy(ctx->device->public_key, mnl_attr_get_payload(attr), sizeof(ctx->device->public_key));
+               if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key))
+                       memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key));
                break;
        case WGDEVICE_A_LISTEN_PORT:
                if (!mnl_attr_validate(attr, MNL_TYPE_U16))
-                       ctx->device->listen_port = mnl_attr_get_u16(attr);
+                       device->listen_port = mnl_attr_get_u16(attr);
                break;
        case WGDEVICE_A_FWMARK:
                if (!mnl_attr_validate(attr, MNL_TYPE_U32))
-                       ctx->device->fwmark = mnl_attr_get_u32(attr);
+                       device->fwmark = mnl_attr_get_u32(attr);
                break;
        case WGDEVICE_A_PEERS:
-               return mnl_attr_parse_nested(attr, parse_peers, ctx);
+               return mnl_attr_parse_nested(attr, parse_peers, device);
        default:
                warn_unrecognized("netlink");
        }
@@ -849,42 +846,41 @@ static int read_device_cb(const struct nlmsghdr *nlh, void *data)
 
 static void coalesce_peers(struct wgdevice *device)
 {
-       struct wgallowedip *allowedip;
        struct wgpeer *old_next_peer, *peer = device->first_peer;
+
        while (peer && peer->next_peer) {
                if (memcmp(peer->public_key, peer->next_peer->public_key, WG_KEY_LEN)) {
                        peer = peer->next_peer;
                        continue;
                }
-               /* TODO: It would be more efficient to store the tail, rather than having to seek to the end each time. */
-               for (allowedip = peer->first_allowedip; allowedip && allowedip->next_allowedip; allowedip = allowedip->next_allowedip);
-
-               if (!allowedip)
+               if (!peer->first_allowedip) {
                        peer->first_allowedip = peer->next_peer->first_allowedip;
-               else
-                       allowedip->next_allowedip = peer->next_peer->first_allowedip;
+                       peer->last_allowedip = peer->next_peer->last_allowedip;
+               } else {
+                       peer->last_allowedip->next_allowedip = peer->next_peer->first_allowedip;
+                       peer->last_allowedip = peer->next_peer->last_allowedip;
+               }
                old_next_peer = peer->next_peer;
                peer->next_peer = old_next_peer->next_peer;
                free(old_next_peer);
        }
 }
 
-static int kernel_get_device(struct wgdevice **dev, const char *interface)
+static int kernel_get_device(struct wgdevice **device, const char *interface)
 {
        int ret = 0;
        struct nlmsghdr *nlh;
        struct mnlg_socket *nlg;
-       struct get_device_ctx ctx = { 0 };
 
 try_again:
-       *dev = ctx.device = calloc(1, sizeof(struct wgdevice));
-       if (!*dev)
+       *device = calloc(1, sizeof(struct wgdevice));
+       if (!*device)
                return -errno;
 
        nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
        if (!nlg) {
-               free_wgdevice(*dev);
-               *dev = NULL;
+               free_wgdevice(*device);
+               *device = NULL;
                return -errno;
        }
 
@@ -895,20 +891,20 @@ try_again:
                goto out;
        }
        errno = 0;
-       if (mnlg_socket_recv_run(nlg, read_device_cb, &ctx) < 0) {
+       if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) {
                ret = errno ? -errno : -EINVAL;
                goto out;
        }
-       coalesce_peers(*dev);
+       coalesce_peers(*device);
 
 out:
        if (nlg)
                mnlg_socket_close(nlg);
        if (ret) {
-               free_wgdevice(*dev);
+               free_wgdevice(*device);
                if (ret == -EINTR)
                        goto try_again;
-               *dev = NULL;
+               *device = NULL;
        }
        errno = -ret;
        return ret;