]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/network/netdev/wireguard.c
network: use netlink_message_append_{in_addr,sockaddr}_union()
[thirdparty/systemd.git] / src / network / netdev / wireguard.c
index 23db87a0d5f5c4484db85ec4e9c3db3f1802910c..7d35afae6d74c7d9c0ebcdaba683c9c85488a4e5 100644 (file)
@@ -1,36 +1,25 @@
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
-    This file is part of systemd.
-
-    Copyright 2016-2017 Jörg Thalheim <joerg@thalheim.io>
-    Copyright 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
-
-    systemd is free software; you can redistribute it and/or modify it
-    under the terms of the GNU Lesser General Public License as published by
-    the Free Software Foundation; either version 2.1 of the License, or
-    (at your option) any later version.
-
-    systemd is distributed in the hope that it will be useful, but
-    WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-    Lesser General Public License for more details.
-
-    You should have received a copy of the GNU Lesser General Public License
-    along with systemd; If not, see <http://www.gnu.org/licenses/>.
+  Copyright © 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 ***/
 
 #include <sys/ioctl.h>
 #include <net/if.h>
 
+#include "sd-resolve.h"
+
 #include "alloc-util.h"
-#include "parse-util.h"
 #include "fd-util.h"
-#include "strv.h"
 #include "hexdecoct.h"
-#include "string-util.h"
-#include "wireguard.h"
+#include "netlink-util.h"
 #include "networkd-link.h"
-#include "networkd-util.h"
 #include "networkd-manager.h"
+#include "networkd-util.h"
+#include "parse-util.h"
+#include "resolve-private.h"
+#include "string-util.h"
+#include "strv.h"
+#include "wireguard.h"
 #include "wireguard-netlink.h"
 
 static void resolve_endpoints(NetDev *netdev);
@@ -43,10 +32,13 @@ static WireguardPeer *wireguard_peer_new(Wireguard *w, unsigned section) {
         if (w->last_peer_section == section && w->peers)
                 return w->peers;
 
-        peer = new0(WireguardPeer, 1);
+        peer = new(WireguardPeer, 1);
         if (!peer)
                 return NULL;
-        peer->flags = WGPEER_F_REPLACE_ALLOWEDIPS;
+
+        *peer = (WireguardPeer) {
+                .flags = WGPEER_F_REPLACE_ALLOWEDIPS,
+        };
 
         LIST_PREPEND(peers, w->peers, peer);
         w->last_peer_section = section;
@@ -54,22 +46,133 @@ static WireguardPeer *wireguard_peer_new(Wireguard *w, unsigned section) {
         return peer;
 }
 
-static int set_wireguard_interface(NetDev *netdev) {
+static int wireguard_set_ipmask_one(NetDev *netdev, sd_netlink_message *message, const WireguardIPmask *mask, uint16_t index) {
         int r;
-        unsigned int i, j;
-        WireguardPeer *peer, *peer_start;
-        WireguardIPmask *mask, *mask_start = NULL;
+
+        assert(message);
+        assert(mask);
+        assert(index > 0);
+
+        /* This returns 1 on success, 0 on recoverable error, and negative errno on failure. */
+
+        r = sd_netlink_message_open_array(message, index);
+        if (r < 0)
+                return 0;
+
+        r = sd_netlink_message_append_u16(message, WGALLOWEDIP_A_FAMILY, mask->family);
+        if (r < 0)
+                goto cancel;
+
+        r = netlink_message_append_in_addr_union(message, WGALLOWEDIP_A_IPADDR, mask->family, &mask->ip);
+        if (r < 0)
+                goto cancel;
+
+        r = sd_netlink_message_append_u8(message, WGALLOWEDIP_A_CIDR_MASK, mask->cidr);
+        if (r < 0)
+                goto cancel;
+
+        r = sd_netlink_message_close_container(message);
+        if (r < 0)
+                return log_netdev_error_errno(netdev, r, "Could not add wireguard allowed ip: %m");
+
+        return 1;
+
+cancel:
+        r = sd_netlink_message_cancel_array(message);
+        if (r < 0)
+                return log_netdev_error_errno(netdev, r, "Could not cancel wireguard allowed ip message attribute: %m");
+
+        return 0;
+}
+
+static int wireguard_set_peer_one(NetDev *netdev, sd_netlink_message *message, const WireguardPeer *peer, uint16_t index, WireguardIPmask **mask_start) {
+        WireguardIPmask *mask, *start;
+        uint16_t j = 0;
+        int r;
+
+        assert(message);
+        assert(peer);
+        assert(index > 0);
+        assert(mask_start);
+
+        /* This returns 1 on success, 0 on recoverable error, and negative errno on failure. */
+
+        start = *mask_start ?: peer->ipmasks;
+
+        r = sd_netlink_message_open_array(message, index);
+        if (r < 0)
+                return 0;
+
+        r = sd_netlink_message_append_data(message, WGPEER_A_PUBLIC_KEY, &peer->public_key, sizeof(peer->public_key));
+        if (r < 0)
+                goto cancel;
+
+        if (!*mask_start) {
+                r = sd_netlink_message_append_data(message, WGPEER_A_PRESHARED_KEY, &peer->preshared_key, WG_KEY_LEN);
+                if (r < 0)
+                        goto cancel;
+
+                r = sd_netlink_message_append_u32(message, WGPEER_A_FLAGS, peer->flags);
+                if (r < 0)
+                        goto cancel;
+
+                r = sd_netlink_message_append_u16(message, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval);
+                if (r < 0)
+                        goto cancel;
+
+                if (IN_SET(peer->endpoint.sa.sa_family, AF_INET, AF_INET6)) {
+                        r = netlink_message_append_sockaddr_union(message, WGPEER_A_ENDPOINT, &peer->endpoint);
+                        if (r < 0)
+                                goto cancel;
+                }
+        }
+
+        r = sd_netlink_message_open_container(message, WGPEER_A_ALLOWEDIPS);
+        if (r < 0)
+                goto cancel;
+
+        LIST_FOREACH(ipmasks, mask, start) {
+                r = wireguard_set_ipmask_one(netdev, message, mask, ++j);
+                if (r < 0)
+                        return r;
+                if (r == 0)
+                        break;
+        }
+
+        r = sd_netlink_message_close_container(message);
+        if (r < 0)
+                return log_netdev_error_errno(netdev, r, "Could not add wireguard allowed ip: %m");
+
+        r = sd_netlink_message_close_container(message);
+        if (r < 0)
+                return log_netdev_error_errno(netdev, r, "Could not add wireguard peer: %m");
+
+        *mask_start = mask; /* Start next cycle from this mask. */
+        return !mask;
+
+cancel:
+        r = sd_netlink_message_cancel_array(message);
+        if (r < 0)
+                return log_netdev_error_errno(netdev, r, "Could not cancel wireguard peers: %m");
+
+        return 0;
+}
+
+static int wireguard_set_interface(NetDev *netdev) {
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *message = NULL;
-        Wireguard *w;
+        WireguardIPmask *mask_start = NULL;
+        WireguardPeer *peer, *peer_start;
         uint32_t serial;
+        Wireguard *w;
+        int r;
 
         assert(netdev);
         w = WIREGUARD(netdev);
         assert(w);
 
-        peer_start = w->peers;
+        for (peer_start = w->peers; peer_start; ) {
+                uint16_t i = 0;
 
-        do {
                 message = sd_netlink_message_unref(message);
 
                 r = sd_genl_message_new(netdev->manager->genl, SD_GENL_WIREGUARD, WG_CMD_SET_DEVICE, &message);
@@ -102,97 +205,14 @@ static int set_wireguard_interface(NetDev *netdev) {
                 if (r < 0)
                         return log_netdev_error_errno(netdev, r, "Could not append wireguard peer attributes: %m");
 
-                i = 0;
-
                 LIST_FOREACH(peers, peer, peer_start) {
-                        r = sd_netlink_message_open_array(message, ++i);
-                        if (r < 0)
-                                break;
-
-                        r = sd_netlink_message_append_data(message, WGPEER_A_PUBLIC_KEY, &peer->public_key, sizeof(peer->public_key));
+                        r = wireguard_set_peer_one(netdev, message, peer, ++i, &mask_start);
                         if (r < 0)
+                                return r;
+                        if (r == 0)
                                 break;
-
-                        if (!mask_start) {
-                                r = sd_netlink_message_append_data(message, WGPEER_A_PRESHARED_KEY, &peer->preshared_key, WG_KEY_LEN);
-                                if (r < 0)
-                                        break;
-
-                                r = sd_netlink_message_append_u32(message, WGPEER_A_FLAGS, peer->flags);
-                                if (r < 0)
-                                        break;
-
-                                r = sd_netlink_message_append_u32(message, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval);
-                                if (r < 0)
-                                        break;
-
-                                if (peer->endpoint.sa.sa_family == AF_INET) {
-                                        r = sd_netlink_message_append_data(message, WGPEER_A_ENDPOINT, &peer->endpoint.in, sizeof(peer->endpoint.in));
-                                        if (r < 0)
-                                                break;
-                                } else if (peer->endpoint.sa.sa_family == AF_INET6) {
-                                        r = sd_netlink_message_append_data(message, WGPEER_A_ENDPOINT, &peer->endpoint.in6, sizeof(peer->endpoint.in6));
-                                        if (r < 0)
-                                                break;
-                                }
-
-                                mask_start = peer->ipmasks;
-                        }
-
-                        r = sd_netlink_message_open_container(message, WGPEER_A_ALLOWEDIPS);
-                        if (r < 0) {
-                                mask_start = NULL;
-                                break;
-                        }
-                        j = 0;
-                        LIST_FOREACH(ipmasks, mask, mask_start) {
-                                r = sd_netlink_message_open_array(message, ++j);
-                                if (r < 0)
-                                        break;
-
-                                r = sd_netlink_message_append_u16(message, WGALLOWEDIP_A_FAMILY, mask->family);
-                                if (r < 0)
-                                        break;
-
-                                if (mask->family == AF_INET) {
-                                        r = sd_netlink_message_append_in_addr(message, WGALLOWEDIP_A_IPADDR, &mask->ip.in);
-                                        if (r < 0)
-                                                break;
-                                } else if (mask->family == AF_INET6) {
-                                        r = sd_netlink_message_append_in6_addr(message, WGALLOWEDIP_A_IPADDR, &mask->ip.in6);
-                                        if (r < 0)
-                                                break;
-                                }
-
-                                r = sd_netlink_message_append_u8(message, WGALLOWEDIP_A_CIDR_MASK, mask->cidr);
-                                if (r < 0)
-                                        break;
-
-                                r = sd_netlink_message_close_container(message);
-                                if (r < 0)
-                                        return log_netdev_error_errno(netdev, r, "Could not add wireguard allowed ip: %m");
-                        }
-                        mask_start = mask;
-                        if (mask_start) {
-                                r = sd_netlink_message_cancel_array(message);
-                                if (r < 0)
-                                        return log_netdev_error_errno(netdev, r, "Could not cancel wireguard allowed ip message attribute: %m");
-                        }
-                        r = sd_netlink_message_close_container(message);
-                        if (r < 0)
-                                return log_netdev_error_errno(netdev, r, "Could not add wireguard allowed ip: %m");
-
-                        r = sd_netlink_message_close_container(message);
-                        if (r < 0)
-                                return log_netdev_error_errno(netdev, r, "Could not add wireguard peer: %m");
-                }
-
-                peer_start = peer;
-                if (peer_start && !mask_start) {
-                        r = sd_netlink_message_cancel_array(message);
-                        if (r < 0)
-                                return log_netdev_error_errno(netdev, r, "Could not cancel wireguard peers: %m");
                 }
+                peer_start = peer; /* Start next cycle from this peer. */
 
                 r = sd_netlink_message_close_container(message);
                 if (r < 0)
@@ -201,8 +221,7 @@ static int set_wireguard_interface(NetDev *netdev) {
                 r = sd_netlink_send(netdev->manager->genl, message, &serial);
                 if (r < 0)
                         return log_netdev_error_errno(netdev, r, "Could not set wireguard device: %m");
-
-        } while (peer || mask_start);
+        }
 
         return 0;
 }
@@ -210,12 +229,19 @@ static int set_wireguard_interface(NetDev *netdev) {
 static WireguardEndpoint* wireguard_endpoint_free(WireguardEndpoint *e) {
         if (!e)
                 return NULL;
-        netdev_unref(e->netdev);
         e->host = mfree(e->host);
         e->port = mfree(e->port);
         return mfree(e);
 }
 
+static void wireguard_endpoint_destroy_callback(WireguardEndpoint *e) {
+        assert(e);
+        assert(e->netdev);
+
+        netdev_unref(e->netdev);
+        wireguard_endpoint_free(e);
+}
+
 DEFINE_TRIVIAL_CLEANUP_FUNC(WireguardEndpoint*, wireguard_endpoint_free);
 
 static int on_resolve_retry(sd_event_source *s, usec_t usec, void *userdata) {
@@ -226,8 +252,10 @@ static int on_resolve_retry(sd_event_source *s, usec_t usec, void *userdata) {
         w = WIREGUARD(netdev);
         assert(w);
 
-        w->resolve_retry_event_source = sd_event_source_unref(w->resolve_retry_event_source);
+        if (!netdev_is_managed(netdev))
+                return 0;
 
+        assert(!w->unresolved_endpoints);
         w->unresolved_endpoints = TAKE_PTR(w->failed_endpoints);
 
         resolve_endpoints(netdev);
@@ -246,28 +274,29 @@ static int exponential_backoff_milliseconds(unsigned n_retries) {
 static int wireguard_resolve_handler(sd_resolve_query *q,
                                      int ret,
                                      const struct addrinfo *ai,
-                                     void *userdata) {
+                                     WireguardEndpoint *e) {
+        _cleanup_(netdev_unrefp) NetDev *netdev_will_unrefed = NULL;
         NetDev *netdev;
         Wireguard *w;
-        _cleanup_(wireguard_endpoint_freep) WireguardEndpoint *e;
         int r;
 
-        assert(userdata);
-        e = userdata;
-        netdev = e->netdev;
+        assert(e);
+        assert(e->netdev);
 
-        assert(netdev);
+        netdev = e->netdev;
         w = WIREGUARD(netdev);
         assert(w);
 
-        w->resolve_query = sd_resolve_query_unref(w->resolve_query);
+        if (!netdev_is_managed(netdev))
+                return 0;
 
         if (ret != 0) {
                 log_netdev_error(netdev, "Failed to resolve host '%s:%s': %s", e->host, e->port, gai_strerror(ret));
                 LIST_PREPEND(endpoints, w->failed_endpoints, e);
-                e = NULL;
+                (void) sd_resolve_query_set_destroy_callback(q, NULL); /* Avoid freeing endpoint by destroy callback. */
+                netdev_will_unrefed = netdev; /* But netdev needs to be unrefed. */
         } else if ((ai->ai_family == AF_INET && ai->ai_addrlen == sizeof(struct sockaddr_in)) ||
-                        (ai->ai_family == AF_INET6 && ai->ai_addrlen == sizeof(struct sockaddr_in6)))
+                   (ai->ai_family == AF_INET6 && ai->ai_addrlen == sizeof(struct sockaddr_in6)))
                 memcpy(&e->peer->endpoint, ai->ai_addr, ai->ai_addrlen);
         else
                 log_netdev_error(netdev, "Neither IPv4 nor IPv6 address found for peer endpoint: %s:%s", e->host, e->port);
@@ -277,57 +306,74 @@ static int wireguard_resolve_handler(sd_resolve_query *q,
                 return 0;
         }
 
-        set_wireguard_interface(netdev);
+        (void) wireguard_set_interface(netdev);
         if (w->failed_endpoints) {
+                _cleanup_(sd_event_source_unrefp) sd_event_source *s = NULL;
+
                 w->n_retries++;
                 r = sd_event_add_time(netdev->manager->event,
-                                      &w->resolve_retry_event_source,
+                                      &s,
                                       CLOCK_MONOTONIC,
                                       now(CLOCK_MONOTONIC) + exponential_backoff_milliseconds(w->n_retries),
                                       0,
                                       on_resolve_retry,
                                       netdev);
-                if (r < 0)
+                if (r < 0) {
                         log_netdev_warning_errno(netdev, r, "Could not arm resolve retry handler: %m");
+                        return 0;
+                }
+
+                r = sd_event_source_set_destroy_callback(s, (sd_event_destroy_t) netdev_destroy_callback);
+                if (r < 0) {
+                        log_netdev_warning_errno(netdev, r, "Failed to set destroy callback to event source: %m");
+                        return 0;
+                }
+
+                (void) sd_event_source_set_floating(s, true);
+                netdev_ref(netdev);
         }
 
         return 0;
 }
 
 static void resolve_endpoints(NetDev *netdev) {
-        int r = 0;
-        Wireguard *w;
-        WireguardEndpoint *endpoint;
         static const struct addrinfo hints = {
                 .ai_family = AF_UNSPEC,
                 .ai_socktype = SOCK_DGRAM,
                 .ai_protocol = IPPROTO_UDP
         };
+        WireguardEndpoint *endpoint;
+        Wireguard *w;
+        int r = 0;
 
         assert(netdev);
         w = WIREGUARD(netdev);
         assert(w);
 
         LIST_FOREACH(endpoints, endpoint, w->unresolved_endpoints) {
-                r = sd_resolve_getaddrinfo(netdev->manager->resolve,
-                                           &w->resolve_query,
-                                           endpoint->host,
-                                           endpoint->port,
-                                           &hints,
-                                           wireguard_resolve_handler,
-                                           endpoint);
+                r = resolve_getaddrinfo(netdev->manager->resolve,
+                                        NULL,
+                                        endpoint->host,
+                                        endpoint->port,
+                                        &hints,
+                                        wireguard_resolve_handler,
+                                        wireguard_endpoint_destroy_callback,
+                                        endpoint);
 
                 if (r == -ENOBUFS)
                         break;
+                if (r < 0) {
+                        log_netdev_error_errno(netdev, r, "Failed to create resolver: %m");
+                        continue;
+                }
 
-                LIST_REMOVE(endpoints, w->unresolved_endpoints, endpoint);
+                /* Avoid freeing netdev. It will be unrefed by the destroy callback. */
+                netdev_ref(netdev);
 
-                if (r < 0)
-                        log_netdev_error_errno(netdev, r, "Failed create resolver: %m");
+                LIST_REMOVE(endpoints, w->unresolved_endpoints, endpoint);
         }
 }
 
-
 static int netdev_wireguard_post_create(NetDev *netdev, Link *link, sd_netlink_message *m) {
         Wireguard *w;
 
@@ -335,7 +381,7 @@ static int netdev_wireguard_post_create(NetDev *netdev, Link *link, sd_netlink_m
         w = WIREGUARD(netdev);
         assert(w);
 
-        set_wireguard_interface(netdev);
+        (void) wireguard_set_interface(netdev);
         resolve_endpoints(netdev);
         return 0;
 }
@@ -467,7 +513,6 @@ int config_parse_wireguard_preshared_key(const char *unit,
                                    peer->preshared_key);
 }
 
-
 int config_parse_wireguard_public_key(const char *unit,
                                       const char *filename,
                                       unsigned line,
@@ -548,12 +593,15 @@ int config_parse_wireguard_allowed_ips(const char *unit,
                         return 0;
                 }
 
-                ipmask = new0(WireguardIPmask, 1);
+                ipmask = new(WireguardIPmask, 1);
                 if (!ipmask)
                         return log_oom();
-                ipmask->family = family;
-                ipmask->ip.in6 = addr.in6;
-                ipmask->cidr = prefixlen;
+
+                *ipmask = (WireguardIPmask) {
+                        .family = family,
+                        .ip.in6 = addr.in6,
+                        .cidr = prefixlen,
+                };
 
                 LIST_PREPEND(ipmasks, peer->ipmasks, ipmask);
         }
@@ -589,10 +637,6 @@ int config_parse_wireguard_endpoint(const char *unit,
         if (!peer)
                 return log_oom();
 
-        endpoint = new0(WireguardEndpoint, 1);
-        if (!endpoint)
-                return log_oom();
-
         if (rvalue[0] == '[') {
                 begin = &rvalue[1];
                 end = strchr(rvalue, ']');
@@ -626,12 +670,17 @@ int config_parse_wireguard_endpoint(const char *unit,
         if (!port)
                 return log_oom();
 
-        endpoint->peer = TAKE_PTR(peer);
-        endpoint->host = TAKE_PTR(host);
-        endpoint->port = TAKE_PTR(port);
-        endpoint->netdev = netdev_ref(data);
-        LIST_PREPEND(endpoints, w->unresolved_endpoints, endpoint);
-        endpoint = NULL;
+        endpoint = new(WireguardEndpoint, 1);
+        if (!endpoint)
+                return log_oom();
+
+        *endpoint = (WireguardEndpoint) {
+                .peer = TAKE_PTR(peer),
+                .host = TAKE_PTR(host),
+                .port = TAKE_PTR(port),
+                .netdev = data,
+        };
+        LIST_PREPEND(endpoints, w->unresolved_endpoints, TAKE_PTR(endpoint));
 
         return 0;
 }
@@ -690,11 +739,11 @@ static void wireguard_done(NetDev *netdev) {
         Wireguard *w;
         WireguardPeer *peer;
         WireguardIPmask *mask;
+        WireguardEndpoint *e;
 
         assert(netdev);
         w = WIREGUARD(netdev);
-        assert(!w->unresolved_endpoints);
-        w->resolve_retry_event_source = sd_event_source_unref(w->resolve_retry_event_source);
+        assert(w);
 
         while ((peer = w->peers)) {
                 LIST_REMOVE(peers, w->peers, peer);
@@ -704,6 +753,16 @@ static void wireguard_done(NetDev *netdev) {
                 }
                 free(peer);
         }
+
+        while ((e = w->unresolved_endpoints)) {
+                LIST_REMOVE(endpoints, w->unresolved_endpoints, e);
+                wireguard_endpoint_free(e);
+        }
+
+        while ((e = w->failed_endpoints)) {
+                LIST_REMOVE(endpoints, w->failed_endpoints, e);
+                wireguard_endpoint_free(e);
+        }
 }
 
 const NetDevVTable wireguard_vtable = {