]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon/network: reintroduce net.close() wildcard semantics
authorOto Šťáva <oto.stava@nic.cz>
Wed, 11 May 2022 08:10:25 +0000 (10:10 +0200)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Fri, 13 May 2022 11:16:43 +0000 (13:16 +0200)
daemon/network.c
lib/utils.c
lib/utils.h

index bfce0a9e290aee4e3b77d67aad69fb997b59e523..8e4bc3f3ebcf501e8de4f383102e6eacae549b31 100644 (file)
@@ -46,9 +46,10 @@ struct __attribute__((packed)) endpoint_key_ifname {
 /** Used for reserving enough storage for `endpoint_key`. */
 struct endpoint_key_storage {
        union {
+               enum endpoint_key_type type;
                struct endpoint_key_sockaddr sa;
                struct endpoint_key_ifname ifname;
-               const char bytes[1]; /* for easier casting */
+               char bytes[1]; /* for easier casting */
        };
 };
 
@@ -706,18 +707,10 @@ void network_proxy_reset(struct network *net)
        trie_clear(net->proxy_addrs6);
 }
 
-int network_close(struct network *net, const char *addr_str, int port)
+static int endpoints_close(struct network *net,
+                           struct endpoint_key_storage *key, ssize_t keylen,
+                           endpoint_array_t *ep_array, int port)
 {
-       auto_free struct sockaddr *addr = kr_straddr_socket(addr_str, port, NULL);
-       struct endpoint_key_storage key;
-       ssize_t keylen = endpoint_key_create(&key, addr_str, addr);
-       if (keylen < 0)
-               return keylen;
-       trie_val_t *val = trie_get_try(net->endpoints, key.bytes, keylen);
-       if (!val)
-               return kr_error(ENOENT);
-       endpoint_array_t *ep_array = *val;
-
        size_t i = 0;
        bool matched = false; /*< at least one match */
        while (i < ep_array->len) {
@@ -735,6 +728,100 @@ int network_close(struct network *net, const char *addr_str, int port)
                return kr_error(ENOENT);
        }
 
+       return kr_ok();
+}
+
+static bool endpoint_key_addr_matches(struct endpoint_key_storage *key_a,
+                                      struct endpoint_key_storage *key_b)
+{
+       if (key_a->type != key_b->type)
+               return false;
+
+       if (key_a->type == ENDPOINT_KEY_IFNAME)
+               return strncmp(key_a->ifname.ifname,
+                              key_b->ifname.ifname,
+                              sizeof(key_a->ifname.ifname)) == 0;
+
+       if (key_a->type == ENDPOINT_KEY_SOCKADDR) {
+               return kr_sockaddr_key_same_addr(
+                               key_a->sa.sa_key.bytes, key_b->sa.sa_key.bytes);
+       }
+
+       kr_assert(false);
+       return kr_error(EINVAL);
+}
+
+struct endpoint_key_with_len {
+       struct endpoint_key_storage key;
+       size_t keylen;
+};
+typedef array_t(struct endpoint_key_with_len) endpoint_key_array_t;
+
+struct endpoint_close_wildcard_context {
+       struct network *net;
+       struct endpoint_key_storage *match_key;
+       endpoint_key_array_t del;
+       int ret;
+};
+
+static int endpoints_close_wildcard(const char *s_key, uint32_t keylen, trie_val_t *val, void *baton)
+{
+       struct endpoint_close_wildcard_context *ctx = baton;
+       struct endpoint_key_storage *key = (struct endpoint_key_storage *)s_key;
+
+       if (!endpoint_key_addr_matches(key, ctx->match_key))
+               return kr_ok();
+
+       endpoint_array_t *ep_array = *val;
+       int ret = endpoints_close(ctx->net, key, keylen, ep_array, -1);
+       if (ret)
+               ctx->ret = ret;
+
+       if (ep_array->len == 0) {
+               struct endpoint_key_with_len to_del = {
+                       .key = *key,
+                       .keylen = keylen
+               };
+               array_push(ctx->del, to_del);
+       }
+
+       return kr_ok();
+}
+
+int network_close(struct network *net, const char *addr_str, int port)
+{
+       auto_free struct sockaddr *addr = kr_straddr_socket(addr_str, port, NULL);
+       struct endpoint_key_storage key;
+       ssize_t keylen = endpoint_key_create(&key, addr_str, addr);
+       if (keylen < 0)
+               return keylen;
+
+       if (port < 0) {
+               struct endpoint_close_wildcard_context ctx = {
+                       .net = net,
+                       .match_key = &key
+               };
+               array_init(ctx.del);
+               trie_apply_with_key(net->endpoints, endpoints_close_wildcard, &ctx);
+               for (size_t i = 0; i < ctx.del.len; i++) {
+                       trie_val_t val;
+                       trie_del(net->endpoints,
+                                ctx.del.at[i].key.bytes, ctx.del.at[i].keylen,
+                                &val);
+                       if (val) {
+                               array_clear(*(endpoint_array_t *) val);
+                               free(val);
+                       }
+               }
+               return ctx.ret;
+       }
+
+       trie_val_t *val = trie_get_try(net->endpoints, key.bytes, keylen);
+       if (!val)
+               return kr_error(ENOENT);
+       endpoint_array_t *ep_array = *val;
+       int ret = endpoints_close(net, &key, keylen, ep_array, port);
+
        /* Collapse key if it has no endpoint. */
        if (ep_array->len == 0) {
                array_clear(*ep_array);
@@ -742,7 +829,7 @@ int network_close(struct network *net, const char *addr_str, int port)
                trie_del(net->endpoints, key.bytes, keylen, NULL);
        }
 
-       return kr_ok();
+       return ret;
 }
 
 void network_new_hostname(struct network *net, struct engine *engine)
index f6f5f4408e899d57079513e410ce8cf3e94dac2f..50e6bfdac7485f0cfcd1187fba5e119e490bec2c 100644 (file)
@@ -393,6 +393,41 @@ struct sockaddr *kr_sockaddr_from_key(struct sockaddr_storage *dst,
        }
 }
 
+bool kr_sockaddr_key_same_addr(const char *key_a, const char *key_b)
+{
+       const struct kr_sockaddr_key *kkey_a = (struct kr_sockaddr_key *) key_a;
+       const struct kr_sockaddr_key *kkey_b = (struct kr_sockaddr_key *) key_b;
+
+       if (kkey_a->family != kkey_b->family)
+               return false;
+
+       ptrdiff_t offset;
+       switch (kkey_a->family) {
+               case AF_INET:
+                       offset = offsetof(struct kr_sockaddr_in_key, address);
+                       break;
+               case AF_INET6:
+                       offset = offsetof(struct kr_sockaddr_in6_key, address);
+                       break;
+
+               case AF_UNIX:;
+                       const struct kr_sockaddr_un_key *unkey_a =
+                               (struct kr_sockaddr_un_key *) key_a;
+                       const struct kr_sockaddr_un_key *unkey_b =
+                               (struct kr_sockaddr_un_key *) key_b;
+
+                       return strncmp(unkey_a->path, unkey_b->path,
+                                      sizeof(unkey_a->path)) == 0;
+
+               default:
+                       kr_assert(false);
+                       return false;
+       }
+
+       size_t len = kr_family_len(kkey_a->family);
+       return memcmp(key_a + offset, key_b + offset, len) == 0;
+}
+
 int kr_sockaddr_cmp(const struct sockaddr *left, const struct sockaddr *right)
 {
        if (!left || !right) {
index 61c6f695a21f29a9fb85ce3c29bce580ec90e7b0..49cd4a8dcd0f6c16da6a5e8c31692d3672273061 100644 (file)
@@ -289,6 +289,11 @@ KR_EXPORT
 struct sockaddr *kr_sockaddr_from_key(struct sockaddr_storage *dst,
                                       const char *key);
 
+/** Checks whether the two keys represent the same address;
+ * does NOT compare the ports. */
+KR_EXPORT
+bool kr_sockaddr_key_same_addr(const char *key_a, const char *key_b);
+
 /** Compare two given sockaddr.
  * return 0 - addresses are equal, error code otherwise.
  */