From 46c0dbe09fc6dec1acddc7a8ff2ccc43022ad2be Mon Sep 17 00:00:00 2001 From: =?utf8?q?Oto=20=C5=A0=C5=A5=C3=A1va?= Date: Wed, 11 May 2022 10:10:25 +0200 Subject: [PATCH] daemon/network: reintroduce net.close() wildcard semantics --- daemon/network.c | 113 +++++++++++++++++++++++++++++++++++++++++------ lib/utils.c | 35 +++++++++++++++ lib/utils.h | 5 +++ 3 files changed, 140 insertions(+), 13 deletions(-) diff --git a/daemon/network.c b/daemon/network.c index bfce0a9e2..8e4bc3f3e 100644 --- a/daemon/network.c +++ b/daemon/network.c @@ -46,9 +46,10 @@ struct __attribute__((packed)) endpoint_key_ifname { /** Used for reserving enough storage for `endpoint_key`. */ struct endpoint_key_storage { union { + enum endpoint_key_type type; struct endpoint_key_sockaddr sa; struct endpoint_key_ifname ifname; - const char bytes[1]; /* for easier casting */ + char bytes[1]; /* for easier casting */ }; }; @@ -706,18 +707,10 @@ void network_proxy_reset(struct network *net) trie_clear(net->proxy_addrs6); } -int network_close(struct network *net, const char *addr_str, int port) +static int endpoints_close(struct network *net, + struct endpoint_key_storage *key, ssize_t keylen, + endpoint_array_t *ep_array, int port) { - auto_free struct sockaddr *addr = kr_straddr_socket(addr_str, port, NULL); - struct endpoint_key_storage key; - ssize_t keylen = endpoint_key_create(&key, addr_str, addr); - if (keylen < 0) - return keylen; - trie_val_t *val = trie_get_try(net->endpoints, key.bytes, keylen); - if (!val) - return kr_error(ENOENT); - endpoint_array_t *ep_array = *val; - size_t i = 0; bool matched = false; /*< at least one match */ while (i < ep_array->len) { @@ -735,6 +728,100 @@ int network_close(struct network *net, const char *addr_str, int port) return kr_error(ENOENT); } + return kr_ok(); +} + +static bool endpoint_key_addr_matches(struct endpoint_key_storage *key_a, + struct endpoint_key_storage *key_b) +{ + if (key_a->type != key_b->type) + return false; + + if (key_a->type == ENDPOINT_KEY_IFNAME) + return strncmp(key_a->ifname.ifname, + key_b->ifname.ifname, + sizeof(key_a->ifname.ifname)) == 0; + + if (key_a->type == ENDPOINT_KEY_SOCKADDR) { + return kr_sockaddr_key_same_addr( + key_a->sa.sa_key.bytes, key_b->sa.sa_key.bytes); + } + + kr_assert(false); + return kr_error(EINVAL); +} + +struct endpoint_key_with_len { + struct endpoint_key_storage key; + size_t keylen; +}; +typedef array_t(struct endpoint_key_with_len) endpoint_key_array_t; + +struct endpoint_close_wildcard_context { + struct network *net; + struct endpoint_key_storage *match_key; + endpoint_key_array_t del; + int ret; +}; + +static int endpoints_close_wildcard(const char *s_key, uint32_t keylen, trie_val_t *val, void *baton) +{ + struct endpoint_close_wildcard_context *ctx = baton; + struct endpoint_key_storage *key = (struct endpoint_key_storage *)s_key; + + if (!endpoint_key_addr_matches(key, ctx->match_key)) + return kr_ok(); + + endpoint_array_t *ep_array = *val; + int ret = endpoints_close(ctx->net, key, keylen, ep_array, -1); + if (ret) + ctx->ret = ret; + + if (ep_array->len == 0) { + struct endpoint_key_with_len to_del = { + .key = *key, + .keylen = keylen + }; + array_push(ctx->del, to_del); + } + + return kr_ok(); +} + +int network_close(struct network *net, const char *addr_str, int port) +{ + auto_free struct sockaddr *addr = kr_straddr_socket(addr_str, port, NULL); + struct endpoint_key_storage key; + ssize_t keylen = endpoint_key_create(&key, addr_str, addr); + if (keylen < 0) + return keylen; + + if (port < 0) { + struct endpoint_close_wildcard_context ctx = { + .net = net, + .match_key = &key + }; + array_init(ctx.del); + trie_apply_with_key(net->endpoints, endpoints_close_wildcard, &ctx); + for (size_t i = 0; i < ctx.del.len; i++) { + trie_val_t val; + trie_del(net->endpoints, + ctx.del.at[i].key.bytes, ctx.del.at[i].keylen, + &val); + if (val) { + array_clear(*(endpoint_array_t *) val); + free(val); + } + } + return ctx.ret; + } + + trie_val_t *val = trie_get_try(net->endpoints, key.bytes, keylen); + if (!val) + return kr_error(ENOENT); + endpoint_array_t *ep_array = *val; + int ret = endpoints_close(net, &key, keylen, ep_array, port); + /* Collapse key if it has no endpoint. */ if (ep_array->len == 0) { array_clear(*ep_array); @@ -742,7 +829,7 @@ int network_close(struct network *net, const char *addr_str, int port) trie_del(net->endpoints, key.bytes, keylen, NULL); } - return kr_ok(); + return ret; } void network_new_hostname(struct network *net, struct engine *engine) diff --git a/lib/utils.c b/lib/utils.c index f6f5f4408..50e6bfdac 100644 --- a/lib/utils.c +++ b/lib/utils.c @@ -393,6 +393,41 @@ struct sockaddr *kr_sockaddr_from_key(struct sockaddr_storage *dst, } } +bool kr_sockaddr_key_same_addr(const char *key_a, const char *key_b) +{ + const struct kr_sockaddr_key *kkey_a = (struct kr_sockaddr_key *) key_a; + const struct kr_sockaddr_key *kkey_b = (struct kr_sockaddr_key *) key_b; + + if (kkey_a->family != kkey_b->family) + return false; + + ptrdiff_t offset; + switch (kkey_a->family) { + case AF_INET: + offset = offsetof(struct kr_sockaddr_in_key, address); + break; + case AF_INET6: + offset = offsetof(struct kr_sockaddr_in6_key, address); + break; + + case AF_UNIX:; + const struct kr_sockaddr_un_key *unkey_a = + (struct kr_sockaddr_un_key *) key_a; + const struct kr_sockaddr_un_key *unkey_b = + (struct kr_sockaddr_un_key *) key_b; + + return strncmp(unkey_a->path, unkey_b->path, + sizeof(unkey_a->path)) == 0; + + default: + kr_assert(false); + return false; + } + + size_t len = kr_family_len(kkey_a->family); + return memcmp(key_a + offset, key_b + offset, len) == 0; +} + int kr_sockaddr_cmp(const struct sockaddr *left, const struct sockaddr *right) { if (!left || !right) { diff --git a/lib/utils.h b/lib/utils.h index 61c6f695a..49cd4a8dc 100644 --- a/lib/utils.h +++ b/lib/utils.h @@ -289,6 +289,11 @@ KR_EXPORT struct sockaddr *kr_sockaddr_from_key(struct sockaddr_storage *dst, const char *key); +/** Checks whether the two keys represent the same address; + * does NOT compare the ports. */ +KR_EXPORT +bool kr_sockaddr_key_same_addr(const char *key_a, const char *key_b); + /** Compare two given sockaddr. * return 0 - addresses are equal, error code otherwise. */ -- 2.47.2