]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon/network: Use trie_t instead of map_t for network endpoints
authorOto Šťáva <oto.stava@nic.cz>
Wed, 20 Apr 2022 08:08:28 +0000 (10:08 +0200)
committerOto Šťáva <oto.stava@nic.cz>
Wed, 11 May 2022 08:12:50 +0000 (10:12 +0200)
daemon/bindings/net.c
daemon/network.c
daemon/network.h
lib/utils.c
lib/utils.h

index c6b0f5dc56e1d6753236d64918544451c917d7b1..cd0a3e366f54346c2e46371a2c4fa69da06f39bb 100644 (file)
@@ -8,17 +8,18 @@
 #include "contrib/cleanup.h"
 #include "daemon/network.h"
 #include "daemon/tls.h"
+#include "lib/utils.h"
 
 #include <stdlib.h>
 
 #define PROXY_DATA_STRLEN (INET6_ADDRSTRLEN + 1 + 3 + 1)
 
 /** Table and next index on top of stack -> append entries for given endpoint_array_t. */
-static int net_list_add(const char *key, void *val, void *ext)
+static int net_list_add(const char *b_key, uint32_t key_len, trie_val_t *val, void *ext)
 {
+       endpoint_array_t *ep_array = *val;
        lua_State *L = (lua_State *)ext;
        lua_Integer i = lua_tointeger(L, -1);
-       endpoint_array_t *ep_array = val;
        for (int j = 0; j < ep_array->len; ++j) {
                struct endpoint *ep = &ep_array->at[j];
                lua_newtable(L);  // connection tuple
@@ -57,7 +58,15 @@ static int net_list_add(const char *key, void *val, void *ext)
                }
                lua_setfield(L, -2, "family");
 
-               lua_pushstring(L, key);
+               const char *ip_str_const = network_endpoint_key_str((struct endpoint_key *) b_key);
+               kr_require(ip_str_const);
+               auto_free char *ip_str = strdup(ip_str_const);
+               kr_require(ip_str);
+               char *hm = strchr(ip_str, '#');
+               if (hm) /* Omit port */
+                       *hm = '\0';
+               lua_pushstring(L, ip_str);
+
                if (ep->family == AF_INET || ep->family == AF_INET6) {
                        lua_setfield(L, -2, "ip");
                        lua_pushboolean(L, ep->flags.freebind);
@@ -101,7 +110,7 @@ static int net_list(lua_State *L)
 {
        lua_newtable(L);
        lua_pushinteger(L, 1);
-       map_walk(&the_worker->engine->net.endpoints, net_list_add, L);
+       trie_apply_with_key(the_worker->engine->net.endpoints, net_list_add, L);
        lua_pop(L, 1);
        return 1;
 }
index 312a1cfead6ad34e7a249559db850fb0157e7d31..bfce0a9e290aee4e3b77d67aad69fb997b59e523 100644 (file)
@@ -4,10 +4,12 @@
 
 #include "daemon/network.h"
 
+#include "contrib/cleanup.h"
 #include "daemon/bindings/impl.h"
 #include "daemon/io.h"
 #include "daemon/tls.h"
 #include "daemon/worker.h"
+#include "lib/utils.h"
 
 #if ENABLE_XDP
        #include <libknot/xdp/eth.h>
 #include <sys/un.h>
 #include <unistd.h>
 
+/** Determines the type of `struct endpoint_key`. */
+enum endpoint_key_type
+{
+       ENDPOINT_KEY_SOCKADDR = 1,
+       ENDPOINT_KEY_IFNAME   = 2,
+};
+
+/** Used as a key in the `struct network::endpoints` trie. */
+struct endpoint_key {
+       enum endpoint_key_type type;
+       char data[];
+};
+
+struct __attribute__((packed)) endpoint_key_sockaddr {
+       enum endpoint_key_type type;
+       struct kr_sockaddr_key_storage sa_key;
+};
+
+struct __attribute__((packed)) endpoint_key_ifname {
+       enum endpoint_key_type type;
+       char ifname[128];
+};
+
+/** Used for reserving enough storage for `endpoint_key`. */
+struct endpoint_key_storage {
+       union {
+               struct endpoint_key_sockaddr sa;
+               struct endpoint_key_ifname ifname;
+               const char bytes[1]; /* for easier casting */
+       };
+};
+
+static_assert(_Alignof(struct endpoint_key) <= 4, "endpoint_key must be aligned to <=4");
+static_assert(_Alignof(struct endpoint_key_sockaddr) <= 4, "endpoint_key must be aligned to <=4");
+static_assert(_Alignof(struct endpoint_key_ifname) <= 4, "endpoint_key must be aligned to <=4");
+
 void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
 {
        if (net != NULL) {
                net->loop = loop;
-               net->endpoints = map_make(NULL);
+               net->endpoints = trie_create(NULL);
                net->endpoint_kinds = trie_create(NULL);
                net->proxy_all4 = false;
                net->proxy_all6 = false;
@@ -77,24 +115,29 @@ static int endpoint_open_lua_cb(struct network *net, struct endpoint *ep,
        return kr_ok();
 }
 
-static int engage_endpoint_array(const char *key, void *endpoints, void *net)
+static int engage_endpoint_array(const char *b_key, uint32_t key_len, trie_val_t *val, void *net)
 {
-       endpoint_array_t *eps = (endpoint_array_t *)endpoints;
+       const char *log_addr = network_endpoint_key_str((struct endpoint_key *) b_key);
+       if (!log_addr)
+               log_addr = "[unknown]";
+
+       endpoint_array_t *eps = *val;
        for (int i = 0; i < eps->len; ++i) {
                struct endpoint *ep = &eps->at[i];
                const bool match = !ep->engaged && ep->flags.kind;
                if (!match) continue;
-               int ret = endpoint_open_lua_cb(net, ep, key);
+               int ret = endpoint_open_lua_cb(net, ep, log_addr);
                if (ret) return ret;
        }
        return 0;
 }
+
 int network_engage_endpoints(struct network *net)
 {
        if (net->missing_kind_is_error)
                return kr_ok(); /* maybe weird, but let's make it idempotent */
        net->missing_kind_is_error = true;
-       int ret = map_walk(&net->endpoints, engage_endpoint_array, net);
+       int ret = trie_apply_with_key(net->endpoints, engage_endpoint_array, net);
        if (ret) {
                net->missing_kind_is_error = false; /* avoid the same errors when closing */
                return ret;
@@ -102,6 +145,25 @@ int network_engage_endpoints(struct network *net)
        return kr_ok();
 }
 
+const char *network_endpoint_key_str(const struct endpoint_key *key)
+{
+       switch (key->type)
+       {
+       case ENDPOINT_KEY_SOCKADDR:;
+               const struct endpoint_key_sockaddr *sa_key =
+                       (struct endpoint_key_sockaddr *) key;
+               struct sockaddr_storage sa_storage;
+               struct sockaddr *sa = kr_sockaddr_from_key(&sa_storage, (const char *) &sa_key->sa_key);
+               return kr_straddr(sa);
+       case ENDPOINT_KEY_IFNAME:;
+               const struct endpoint_key_ifname *if_key =
+                       (struct endpoint_key_ifname *) key;
+               return if_key->ifname;
+       default:
+               kr_assert(false);
+               return NULL;
+       }
+}
 
 /** Notify the registered function about endpoint about to be closed. */
 static void endpoint_close_lua_cb(struct network *net, struct endpoint *ep)
@@ -174,18 +236,18 @@ static void endpoint_close(struct network *net, struct endpoint *ep, bool force)
 }
 
 /** Endpoint visitor (see @file map.h) */
-static int close_key(const char *key, void *val, void *net)
+static int close_key(trie_val_t *val, void* net)
 {
-       endpoint_array_t *ep_array = val;
+       endpoint_array_t *ep_array = *val;
        for (int i = 0; i < ep_array->len; ++i) {
                endpoint_close(net, &ep_array->at[i], true);
        }
        return 0;
 }
 
-static int free_key(const char *key, void *val, void *ext)
+static int free_key(trie_val_t *val, void* ext)
 {
-       endpoint_array_t *ep_array = val;
+       endpoint_array_t *ep_array = *val;
        array_clear(*ep_array);
        free(ep_array);
        return kr_ok();
@@ -201,9 +263,9 @@ int kind_unregister(trie_val_t *tv, void *L)
 void network_close_force(struct network *net)
 {
        if (net != NULL) {
-               map_walk(&net->endpoints, close_key, net);
-               map_walk(&net->endpoints, free_key, 0);
-               map_clear(&net->endpoints);
+               trie_apply(net->endpoints, close_key, net);
+               trie_apply(net->endpoints, free_key, NULL);
+               trie_clear(net->endpoints);
        }
 }
 
@@ -224,6 +286,7 @@ void network_deinit(struct network *net)
                network_close_force(net);
                trie_apply(net->endpoint_kinds, kind_unregister, the_worker->engine->L);
                trie_free(net->endpoint_kinds);
+               trie_free(net->endpoints);
                network_proxy_free_addr_data(net->proxy_addrs4);
                trie_free(net->proxy_addrs4);
                network_proxy_free_addr_data(net->proxy_addrs6);
@@ -238,21 +301,44 @@ void network_deinit(struct network *net)
        }
 }
 
+static ssize_t endpoint_key_create(struct endpoint_key_storage *dst,
+                                   const char *addr_str,
+                                   const struct sockaddr *sa)
+{
+       memset(dst, 0, sizeof(*dst));
+       if (sa) {
+               struct endpoint_key_sockaddr *key = &dst->sa;
+               key->type = ENDPOINT_KEY_SOCKADDR;
+               ssize_t keylen = kr_sockaddr_key(&key->sa_key, sa);
+               if (keylen < 0)
+                       return keylen;
+               return sizeof(struct endpoint_key) + keylen;
+       } else {
+               struct endpoint_key_ifname *key = &dst->ifname;
+               key->type = ENDPOINT_KEY_IFNAME;
+               strncpy(key->ifname, addr_str, sizeof(key->ifname) - 1);
+               return sizeof(struct endpoint_key) + strnlen(key->ifname, sizeof(key->ifname));
+       }
+}
+
 /** Fetch or create endpoint array and insert endpoint (shallow memcpy). */
-static int insert_endpoint(struct network *net, const char *addr, struct endpoint *ep)
+static int insert_endpoint(struct network *net, const char *addr_str,
+                           const struct sockaddr *addr, struct endpoint *ep)
 {
        /* Fetch or insert address into map */
-       endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
-       if (ep_array == NULL) {
+       struct endpoint_key_storage key;
+       ssize_t keylen = endpoint_key_create(&key, addr_str, addr);
+       if (keylen < 0)
+               return keylen;
+       trie_val_t *val = trie_get_ins(net->endpoints, key.bytes, keylen);
+       endpoint_array_t *ep_array;
+       if (*val) {
+               ep_array = *val;
+       } else {
                ep_array = malloc(sizeof(*ep_array));
-               if (ep_array == NULL) {
-                       return kr_error(ENOMEM);
-               }
-               if (map_set(&net->endpoints, addr, ep_array) != 0) {
-                       free(ep_array);
-                       return kr_error(ENOMEM);
-               }
+               kr_require(ep_array);
                array_init(*ep_array);
+               *val = ep_array;
        }
 
        if (array_reserve(*ep_array, ep_array->len + 1)) {
@@ -269,17 +355,16 @@ static int open_endpoint(struct network *net, const char *addr_str,
 {
        const bool is_control = ep->flags.kind && strcmp(ep->flags.kind, "control") == 0;
        const bool is_xdp     = ep->family == AF_XDP;
-       bool ok = is_xdp
-               ? sa == NULL && ep->fd == -1 && ep->nic_queue >= 0
-                       && ep->flags.sock_type == SOCK_DGRAM && !ep->flags.tls
-               : (sa != NULL) != (ep->fd != -1);
+       bool ok = (!is_xdp)
+               || (sa == NULL && ep->fd == -1 && ep->nic_queue >= 0
+                       && ep->flags.sock_type == SOCK_DGRAM && !ep->flags.tls);
        if (kr_fails_assert(ok))
                return kr_error(EINVAL);
        if (ep->handle) {
                return kr_error(EEXIST);
        }
 
-       if (sa) {
+       if (sa && ep->fd == -1) {
                if (sa->sa_family == AF_UNIX) {
                        struct sockaddr_un *sun = (struct sockaddr_un*)sa;
                        char *dirc = strdup(sun->sun_path);
@@ -363,16 +448,24 @@ finish_ret:
  * Beware that there might be multiple matches, though that's not common.
  * The matching isn't really precise in the sense that it might not find
  * and endpoint that would *collide* the passed one. */
-static struct endpoint * endpoint_get(struct network *net, const char *addr,
-                                       uint16_t port, endpoint_flags_t flags)
+static struct endpoint * endpoint_get(struct network *net,
+                                      const char *addr_str,
+                                      const struct sockaddr *sa,
+                                      endpoint_flags_t flags)
 {
-       endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
-       if (!ep_array) {
+       struct endpoint_key_storage key;
+       ssize_t keylen = endpoint_key_create(&key, addr_str, sa);
+       if (keylen < 0)
                return NULL;
-       }
+       trie_val_t *val = trie_get_try(net->endpoints, key.bytes, keylen);
+       if (!val)
+               return NULL;
+       endpoint_array_t *ep_array = *val;
+
+       uint16_t port = kr_inaddr_port(sa);
        for (int i = 0; i < ep_array->len; ++i) {
                struct endpoint *ep = &ep_array->at[i];
-               if (ep->port == port && endpoint_flags_eq(ep->flags, flags)) {
+               if ((flags.xdp || ep->port == port) && endpoint_flags_eq(ep->flags, flags)) {
                        return ep;
                }
        }
@@ -383,11 +476,11 @@ static struct endpoint * endpoint_get(struct network *net, const char *addr,
  *  \note in XDP case addr_str is interface name
  *  \note ownership of ep.flags.* is taken on success. */
 static int create_endpoint(struct network *net, const char *addr_str,
-                               struct endpoint *ep, const struct sockaddr *sa)
+                           struct endpoint *ep, const struct sockaddr *sa)
 {
        int ret = open_endpoint(net, addr_str, ep, sa);
        if (ret == 0) {
-               ret = insert_endpoint(net, addr_str, ep);
+               ret = insert_endpoint(net, addr_str, sa, ep);
        }
        if (ret != 0 && ep->handle) {
                endpoint_close(net, ep, false);
@@ -448,7 +541,7 @@ int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags)
 
        /* always create endpoint for supervisor supplied fd
         * even if addr+port is not unique */
-       return create_endpoint(net, addr_str, &ep, NULL);
+       return create_endpoint(net, addr_str, &ep, (struct sockaddr *) &ss);
 }
 
 /** Try selecting XDP queue automatically. */
@@ -501,7 +594,7 @@ int network_listen(struct network *net, const char *addr, uint16_t port,
        }
        // XDP: if addr failed to parse as address, we assume it's an interface name.
 
-       if (endpoint_get(net, addr, port, flags)) {
+       if (endpoint_get(net, addr, sa, flags)) {
                return kr_error(EADDRINUSE); // Already listening
        }
 
@@ -613,12 +706,17 @@ void network_proxy_reset(struct network *net)
        trie_clear(net->proxy_addrs6);
 }
 
-int network_close(struct network *net, const char *addr, int port)
+int network_close(struct network *net, const char *addr_str, int port)
 {
-       endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
-       if (!ep_array) {
+       auto_free struct sockaddr *addr = kr_straddr_socket(addr_str, port, NULL);
+       struct endpoint_key_storage key;
+       ssize_t keylen = endpoint_key_create(&key, addr_str, addr);
+       if (keylen < 0)
+               return keylen;
+       trie_val_t *val = trie_get_try(net->endpoints, key.bytes, keylen);
+       if (!val)
                return kr_error(ENOENT);
-       }
+       endpoint_array_t *ep_array = *val;
 
        size_t i = 0;
        bool matched = false; /*< at least one match */
@@ -641,7 +739,7 @@ int network_close(struct network *net, const char *addr, int port)
        if (ep_array->len == 0) {
                array_clear(*ep_array);
                free(ep_array);
-               map_del(&net->endpoints, addr);
+               trie_del(net->endpoints, key.bytes, keylen, NULL);
        }
 
        return kr_ok();
@@ -664,10 +762,10 @@ void network_new_hostname(struct network *net, struct engine *engine)
 }
 
 #ifdef SO_ATTACH_BPF
-static int set_bpf_cb(const char *key, void *val, void *ext)
+static int set_bpf_cb(trie_val_t *val, void *ctx)
 {
-       endpoint_array_t *endpoints = (endpoint_array_t *)val;
-       int *bpffd = (int *)ext;
+       endpoint_array_t *endpoints = *val;
+       int *bpffd = (int *)ctx;
        if (kr_fails_assert(endpoints && bpffd))
                return kr_error(EINVAL);
 
@@ -689,7 +787,7 @@ static int set_bpf_cb(const char *key, void *val, void *ext)
 int network_set_bpf(struct network *net, int bpf_fd)
 {
 #ifdef SO_ATTACH_BPF
-       if (map_walk(&net->endpoints, set_bpf_cb, &bpf_fd) != 0) {
+       if (trie_apply(net->endpoints, set_bpf_cb, &bpf_fd) != 0) {
                /* set_bpf_cb() has returned error. */
                network_clear_bpf(net);
                return 0;
@@ -704,9 +802,9 @@ int network_set_bpf(struct network *net, int bpf_fd)
 }
 
 #ifdef SO_DETACH_BPF
-static int clear_bpf_cb(const char *key, void *val, void *ext)
+static int clear_bpf_cb(trie_val_t *val, void *ctx)
 {
-       endpoint_array_t *endpoints = (endpoint_array_t *)val;
+       endpoint_array_t *endpoints = *val;
        if (kr_fails_assert(endpoints))
                return kr_error(EINVAL);
 
@@ -730,7 +828,7 @@ static int clear_bpf_cb(const char *key, void *val, void *ext)
 void network_clear_bpf(struct network *net)
 {
 #ifdef SO_DETACH_BPF
-       map_walk(&net->endpoints, clear_bpf_cb, NULL);
+       trie_apply(net->endpoints, clear_bpf_cb, NULL);
 #else
        kr_log_error(NETWORK, "SO_DETACH_BPF socket option doesn't supported\n");
        (void)net;
index 0e764b2780eda3d1c265e06136775563490b29c3..4399c044d62762c9a3a19e0bba74eec200de5f3e 100644 (file)
@@ -31,6 +31,8 @@ typedef struct {
        const char *kind; /**< tag for other types: "control" or module-handled kinds */
 } endpoint_flags_t;
 
+struct endpoint_key;
+
 static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2)
 {
        if (f1.sock_type != f2.sock_type)
@@ -78,9 +80,8 @@ struct network {
        uv_loop_t *loop;
 
        /** Map: address string -> endpoint_array_t.
-        * \note even same address-port-flags tuples may appear.
-        * TODO: trie_t, keyed on *binary* address-port pair. */
-       map_t endpoints;
+        * \note even same address-port-flags tuples may appear. */
+       trie_t *endpoints;
 
        /** Registry of callbacks for special endpoint kinds (for opening/closing).
         * Map: kind (lowercased) -> lua function ID converted to void *
@@ -150,6 +151,11 @@ void network_close_force(struct network *net);
  * This only does anything with struct endpoint::flags.kind != NULL. */
 int network_engage_endpoints(struct network *net);
 
+/** Returns a string representation of the specified endpoint key.
+ *
+ * The result points into key or is on static storage like for kr_straddr() */
+const char *network_endpoint_key_str(const struct endpoint_key *key);
+
 int network_set_tls_cert(struct network *net, const char *cert);
 int network_set_tls_key(struct network *net, const char *key);
 void network_new_hostname(struct network *net, struct engine *engine);
index 1f8db78e9528ff6701931fb1d14a073b6367d19d..b43d80aa7c3fae8b6f60e8fb5f3a869457dc1f5e 100644 (file)
@@ -378,6 +378,7 @@ struct sockaddr *kr_sockaddr_from_key(struct sockaddr_storage *dst,
                return (struct sockaddr *) addr_un;
 
        default:
+               kr_assert(false);
                return NULL;
        }
 }
@@ -486,7 +487,10 @@ int kr_straddr_family(const char *addr)
        if (strchr(addr, ':')) {
                return AF_INET6;
        }
-       return AF_INET;
+       if (strchr(addr, '.')) {
+               return AF_INET;
+       }
+       return kr_error(EINVAL);
 }
 
 int kr_family_len(int family)
@@ -531,7 +535,6 @@ struct sockaddr * kr_straddr_socket(const char *addr, int port, knot_mm_t *pool)
                return (struct sockaddr *)res;
        }
        default:
-               kr_assert(false);
                return NULL;
        }
 }
index d9aadfa13ef98d3b9add3dc08e5ad9b04d2f0ff3..61c6f695a21f29a9fb85ce3c29bce580ec90e7b0 100644 (file)
@@ -10,6 +10,7 @@
 #include <stdbool.h>
 #include <sys/socket.h>
 #include <sys/time.h>
+#include <sys/un.h>
 #include <netinet/in.h>
 #include <unistd.h>
 
 /** When knot_pkt is passed from cache without ->wire, this is the ->size. */
 static const size_t KR_PKT_SIZE_NOWIRE = -1;
 
+/** Maximum length (excluding null-terminator) of a presentation-form address
+ * returned by `kr_straddr`. */
+#define KR_STRADDR_MAXLEN 109
+
 /** Used for reserving enough space for the `kr_sockaddr_key` function
  * output. */
 struct kr_sockaddr_key_storage {
@@ -255,6 +260,7 @@ union kr_in_addr {
        struct in6_addr ip6;
 };
 
+/* TODO: rename kr_inaddr functions to kr_sockaddr */
 /** Address bytes for given family. */
 KR_EXPORT KR_PURE
 const char *kr_inaddr(const struct sockaddr *addr);
@@ -315,8 +321,12 @@ static inline char *kr_straddr(const struct sockaddr *addr)
 {
        if (kr_fails_assert(addr)) return NULL;
        /* We are the single-threaded application */
-       static char str[INET6_ADDRSTRLEN + 1 + 5 + 1];
-       size_t len = sizeof(str);
+       static char str[KR_STRADDR_MAXLEN + 1] = {0};
+       if (addr->sa_family == AF_UNIX) {
+               strncpy(str, ((struct sockaddr_un *) addr)->sun_path, sizeof(str) - 1);
+               return str;
+       }
+       size_t len = KR_STRADDR_MAXLEN;
        int ret = kr_inaddr_str(addr, str, &len);
        return ret != kr_ok() || len == 0 ? NULL : str;
 }