]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon/network: stop using combined UDP+TCP endpoints
authorVladimír Čunát <vladimir.cunat@nic.cz>
Mon, 11 Mar 2019 14:31:35 +0000 (15:31 +0100)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Tue, 12 Mar 2019 11:42:02 +0000 (12:42 +0100)
It was confusing, e.g. the new net.list() or net.bpf_set() were wrong.
Implementation cleanup: merge _fd variant into open_endpoint(),
as the code was repetitive and differed in unnecessary places.

daemon/network.c
daemon/network.h

index ed94f7d47de90af1e2949680a189589415dc0f40..bf18291afad47fb78bac5e907b8a14dc4fcf0ac4 100644 (file)
 #if defined(UV_VERSION_HEX)
 #if (__linux__ && SO_REUSEPORT)
   #define handle_init(type, loop, handle, family) do { \
-       uv_ ## type ## _init_ex((loop), (handle), (family)); \
+       uv_ ## type ## _init_ex((loop), (uv_ ## type ## _t *)(handle), (family)); \
        uv_os_fd_t hi_fd = 0; \
-       if (uv_fileno((uv_handle_t *)(handle), &hi_fd) == 0) { \
+       if (uv_fileno((handle), &hi_fd) == 0) { \
                int hi_on = 1; \
                int hi_ret = setsockopt(hi_fd, SOL_SOCKET, SO_REUSEPORT, &hi_on, sizeof(hi_on)); \
                if (hi_ret) { \
-                       return hi_ret; \
+                       return kr_error(errno); \
                } \
        } \
   } while (0)
 /* libuv 1.7.0+ is able to assign fd immediately */
 #else
   #define handle_init(type, loop, handle, family) do { \
-       uv_ ## type ## _init_ex((loop), (handle), (family)); \
+       uv_ ## type ## _init_ex((loop), (uv_ ## type ## _t *)(handle), (family)); \
   } while (0)
 #endif
 #else
   #define handle_init(type, loop, handle, family) \
-       uv_ ## type ## _init((loop), (handle))
+       uv_ ## type ## _init((loop), (uv_ ## type ## _t *)(handle))
 #endif
 
 void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
@@ -76,13 +76,9 @@ static void close_handle(uv_handle_t *handle, bool force)
 
 static int close_endpoint(struct endpoint *ep, bool force)
 {
-       if (ep->udp) {
-               close_handle((uv_handle_t *)ep->udp, force);
+       if (ep->handle) {
+               close_handle(ep->handle, force);
        }
-       if (ep->tcp) {
-               close_handle((uv_handle_t *)ep->tcp, force);
-       }
-
        free(ep);
        return kr_ok();
 }
@@ -143,90 +139,67 @@ static int insert_endpoint(struct network *net, const char *addr, struct endpoin
        return kr_ok();
 }
 
-/** Open endpoint protocols. */
-static int open_endpoint(struct network *net, struct endpoint *ep, struct sockaddr *sa, uint32_t flags)
+/** Open endpoint protocols.  ep->flags were pre-set. */
+static int open_endpoint(struct network *net, struct endpoint *ep,
+                        const struct sockaddr *sa, int fd)
 {
-       int ret = 0;
-       if (flags & NET_UDP) {
-               ep->udp = malloc(sizeof(*ep->udp));
-               if (!ep->udp) {
-                       return kr_error(ENOMEM);
-               }
-               memset(ep->udp, 0, sizeof(*ep->udp));
-               handle_init(udp, net->loop, ep->udp, sa->sa_family); /* can return! */
-               ret = udp_bind(ep->udp, sa);
-               if (ret != 0) {
-                       return ret;
-               }
-               ep->flags |= NET_UDP;
+       if ((sa != NULL) == (fd != -1)) {
+               assert(!EINVAL);
+               return kr_error(EINVAL);
        }
-       if (flags & NET_TCP) {
-               ep->tcp = malloc(sizeof(*ep->tcp));
-               if (!ep->tcp) {
-                       return kr_error(ENOMEM);
-               }
-               memset(ep->tcp, 0, sizeof(*ep->tcp));
-               handle_init(tcp, net->loop, ep->tcp, sa->sa_family); /* can return! */
-               if (flags & NET_TLS) {
-                       ret = tcp_bind_tls(ep->tcp, sa, net->tcp_backlog);
-                       ep->flags |= NET_TLS;
-               } else {
-                       ret = tcp_bind(ep->tcp, sa, net->tcp_backlog);
-               }
-               if (ret != 0) {
-                       return ret;
-               }
-               ep->flags |= NET_TCP;
+       if (ep->handle) {
+               return kr_error(EEXIST);
        }
-       return ret;
-}
 
-/** Open fd as endpoint. */
-static int open_endpoint_fd(struct network *net, struct endpoint *ep, int fd, int sock_type, bool use_tls)
-{
-       int ret = kr_ok();
-       if (sock_type == SOCK_DGRAM) {
-               if (use_tls) {
-                       /* we do not support TLS over UDP */
-                       return kr_error(EBADF);
-               }
-               if (ep->udp) {
-                       return kr_error(EEXIST);
+       if (ep->flags & NET_UDP) {
+               if (ep->flags & (NET_TCP | NET_TLS)) {
+                       assert(!EINVAL);
+                       return kr_error(EINVAL);
                }
-               ep->udp = malloc(sizeof(*ep->udp));
-               if (!ep->udp) {
+               uv_udp_t *ep_handle = calloc(1, sizeof(uv_udp_t));
+               ep->handle = (uv_handle_t *)ep_handle;
+               if (!ep->handle) {
                        return kr_error(ENOMEM);
                }
-               uv_udp_init(net->loop, ep->udp);
-               ret = udp_bindfd(ep->udp, fd);
-               if (ret != 0) {
-                       close_handle((uv_handle_t *)ep->udp, false);
+               if (sa) {
+                       handle_init(udp, net->loop, ep->handle, sa->sa_family);
+                               /*^^ can return! */
+                       return udp_bind(ep_handle, sa);
+               } else {
+                       int ret = uv_udp_init(net->loop, ep_handle);
+                       if (ret == 0) {
+                               ret = udp_bindfd(ep_handle, fd);
+                       }
                        return ret;
                }
-               ep->flags |= NET_UDP;
-               return kr_ok();
-       } else if (sock_type == SOCK_STREAM) {
-               if (ep->tcp) {
-                       return kr_error(EEXIST);
-               }
-               ep->tcp = malloc(sizeof(*ep->tcp));
-               if (!ep->tcp) {
+       } /* else */
+
+       if (ep->flags & NET_TCP) {
+               uv_tcp_t *ep_handle = calloc(1, sizeof(uv_tcp_t));
+               ep->handle = (uv_handle_t *)ep_handle;
+               if (!ep->handle) {
                        return kr_error(ENOMEM);
                }
-               uv_tcp_init(net->loop, ep->tcp);
-               if (use_tls) {
-                       ret = tcp_bindfd_tls(ep->tcp, fd, net->tcp_backlog);
-                       ep->flags |= NET_TLS;
+               if (sa) {
+                       handle_init(tcp, net->loop, ep->handle, sa->sa_family); /* can return! */
                } else {
-                       ret = tcp_bindfd(ep->tcp, fd, net->tcp_backlog);
+                       int ret = uv_tcp_init(net->loop, ep_handle);
+                       if (ret) {
+                               return ret;
+                       }
                }
-               if (ret != 0) {
-                       close_handle((uv_handle_t *)ep->tcp, false);
-                       return ret;
+               if (ep->flags & NET_TLS) {
+                       return sa
+                               ? tcp_bind_tls  (ep_handle, sa, net->tcp_backlog)
+                               : tcp_bindfd_tls(ep_handle, fd, net->tcp_backlog);
+               } else {
+                       return sa
+                               ? tcp_bind  (ep_handle, sa, net->tcp_backlog)
+                               : tcp_bindfd(ep_handle, fd, net->tcp_backlog);
                }
-               ep->flags |= NET_TCP;
-               return kr_ok();
-       }
+       } /* else */
+
+       assert(!EINVAL);
        return kr_error(EINVAL);
 }
 
@@ -246,21 +219,59 @@ static endpoint_array_t *network_get(struct network *net, const char *addr, uint
        return NULL;
 }
 
+/** \note pass either sa != NULL xor fd != -1 */
+static int create_endpoint(struct network *net, const char *addr_str,
+                               uint16_t port, uint16_t flags,
+                               const struct sockaddr *sa, int fd)
+{
+       /* Bind interfaces */
+       struct endpoint *ep = calloc(1, sizeof(*ep));
+       if (!ep) {
+               return kr_error(ENOMEM);
+       }
+       ep->flags = flags;
+       ep->port = port;
+       int ret = open_endpoint(net, ep, sa, fd);
+       if (ret == 0) {
+               ret = insert_endpoint(net, addr_str, ep);
+       }
+       if (ret != 0) {
+               close_endpoint(ep, false);
+       }
+       return ret;
+}
+
 int network_listen_fd(struct network *net, int fd, bool use_tls)
 {
-       /* Extract local address and socket type. */
-       int sock_type = SOCK_DGRAM;
+       /* Extract fd's socket type. */
+       int sock_type;
        socklen_t len = sizeof(sock_type);
        int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &len);
        if (ret != 0) {
+               return kr_error(errno);
+       }
+       uint16_t flags;
+       if (sock_type == SOCK_DGRAM) {
+               flags = NET_UDP;
+               if (use_tls) {
+                       assert(!EINVAL);
+                       return kr_error(EINVAL);
+               }
+       } else if (sock_type == SOCK_STREAM) {
+               flags = NET_TCP;
+               if (use_tls) {
+                       flags |= NET_TLS;
+               }
+       } else {
                return kr_error(EBADF);
        }
+
        /* Extract local address for this socket. */
        struct sockaddr_storage ss = { .ss_family = AF_UNSPEC };
        socklen_t addr_len = sizeof(ss);
        ret = getsockname(fd, (struct sockaddr *)&ss, &addr_len);
        if (ret != 0) {
-               return kr_error(EBADF);
+               return kr_error(errno);
        }
        int port = 0;
        char addr_str[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */
@@ -276,19 +287,10 @@ int network_listen_fd(struct network *net, int fd, bool use_tls)
 
        /* always create endpoint for supervisor supplied fd
         * even if addr+port is not unique */
-       struct endpoint *ep = malloc(sizeof(*ep));
-       memset(ep, 0, sizeof(*ep));
-       ep->flags = NET_DOWN;
-       ep->port = port;
-       ret = insert_endpoint(net, addr_str, ep);
-       if (ret != 0) {
-               return ret;
-       }
-       /* Create a libuv struct for this socket. */
-       return open_endpoint_fd(net, ep, fd, sock_type, use_tls);
+       return create_endpoint(net, addr_str, port, flags, NULL, fd);
 }
 
-int network_listen(struct network *net, const char *addr, uint16_t port, uint32_t flags)
+int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags)
 {
        if (net == NULL || addr == 0 || port == 0) {
                return kr_error(EINVAL);
@@ -302,27 +304,24 @@ int network_listen(struct network *net, const char *addr, uint16_t port, uint32_
 
        /* Parse address. */
        int ret = 0;
-       struct sockaddr_storage sa;
+       union inaddr sa;
        if (strchr(addr, ':') != NULL) {
-               ret = uv_ip6_addr(addr, port, (struct sockaddr_in6 *)&sa);
+               ret = uv_ip6_addr(addr, port, &sa.ip6);
        } else {
-               ret = uv_ip4_addr(addr, port, (struct sockaddr_in *)&sa);
+               ret = uv_ip4_addr(addr, port, &sa.ip4);
        }
        if (ret != 0) {
                return ret;
        }
 
-       /* Bind interfaces */
-       struct endpoint *ep = malloc(sizeof(*ep));
-       memset(ep, 0, sizeof(*ep));
-       ep->flags = NET_DOWN;
-       ep->port = port;
-       ret = open_endpoint(net, ep, (struct sockaddr *)&sa, flags);
-       if (ret == 0) {
-               ret = insert_endpoint(net, addr, ep);
-       }
-       if (ret != 0) {
-               close_endpoint(ep, false);
+       if ((flags & NET_UDP) && (flags & NET_TCP)) {
+               /* We accept  ^^ this shorthand at this API layer. */
+               ret = create_endpoint(net, addr, port, flags & ~NET_TCP, &sa.ip, -1);
+               if (ret == 0) {
+                       ret = create_endpoint(net, addr, port, flags & ~NET_UDP, &sa.ip, -1);
+               }
+       } else {
+               ret = create_endpoint(net, addr, port, flags, &sa.ip, -1);
        }
 
        return ret;
@@ -376,8 +375,8 @@ static int set_bpf_cb(const char *key, void *val, void *ext)
        for (size_t i = 0; i < endpoints->len; i++) {
                struct endpoint *endpoint = endpoints->at[i];
                uv_os_fd_t sockfd = -1;
-               if (endpoint->tcp != NULL) uv_fileno((const uv_handle_t *)endpoint->tcp, &sockfd);
-               if (endpoint->udp != NULL) uv_fileno((const uv_handle_t *)endpoint->udp, &sockfd);
+               if (endpoint->handle != NULL)
+                       uv_fileno(endpoint->handle, &sockfd);
                assert(sockfd != -1);
 
                if (setsockopt(sockfd, SOL_SOCKET, SO_ATTACH_BPF, bpffd, sizeof(int)) != 0) {
@@ -414,8 +413,8 @@ static int clear_bpf_cb(const char *key, void *val, void *ext)
        for (size_t i = 0; i < endpoints->len; i++) {
                struct endpoint *endpoint = endpoints->at[i];
                uv_os_fd_t sockfd = -1;
-               if (endpoint->tcp != NULL) uv_fileno((const uv_handle_t *)endpoint->tcp, &sockfd);
-               if (endpoint->udp != NULL) uv_fileno((const uv_handle_t *)endpoint->udp, &sockfd);
+               if (endpoint->handle != NULL)
+                       uv_fileno(endpoint->handle, &sockfd);
                assert(sockfd != -1);
 
                if (setsockopt(sockfd, SOL_SOCKET, SO_DETACH_BPF, NULL, 0) != 0) {
index 1e80d09a5a3a05bcbcd98e9eba08193cc821248b..61b9835c4d0b68c16a0081adae1b771597cabb69 100644 (file)
 struct engine;
 
 enum endpoint_flag {
-    NET_DOWN = 0 << 0,
-    NET_UDP  = 1 << 0,
-    NET_TCP  = 1 << 1,
-    NET_TLS  = 1 << 2,
+       NET_DOWN = 0,
+       NET_UDP  = 1 << 0,
+       NET_TCP  = 1 << 1,
+       NET_TLS  = 1 << 2, /**< only used together with NET_TCP */
 };
 
+/** Wrapper for a single socket to listen on. */
 struct endpoint {
-    uv_udp_t *udp;
-    uv_tcp_t *tcp;
-    uint16_t port;
-    uint16_t flags;
+       uv_handle_t *handle; /** uv_udp_t or uv_tcp_t */
+       uint16_t port;
+       uint16_t flags; /**< see enum endpoint_flag; (_UDP | _TCP) *not* allowed */
 };
 
 /** @cond internal Array of endpoints */
@@ -52,7 +52,12 @@ struct net_tcp_param {
 
 struct network {
        uv_loop_t *loop;
+
+       /** Map: address string -> endpoint_array_t.
+        * \note even same address-port-flags tuples may appear.
+        * TODO: trie_t, keyed on *binary* address-port pair. */
        map_t endpoints;
+
        struct tls_credentials *tls_credentials;
        tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */
        struct tls_session_ticket_ctx *tls_session_ticket_ctx;
@@ -62,8 +67,14 @@ struct network {
 
 void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog);
 void network_deinit(struct network *net);
+
+/** Start listenting on addr#port.
+ * \param flags see enum endpoint_flag; (NET_UDP | NET_TCP) is allowed. */
+int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags);
+
+/** Start listenting on an open file-descriptor. */
 int network_listen_fd(struct network *net, int fd, bool use_tls);
-int network_listen(struct network *net, const char *addr, uint16_t port, uint32_t flags);
+
 int network_close(struct network *net, const char *addr, uint16_t port);
 int network_set_tls_cert(struct network *net, const char *cert);
 int network_set_tls_key(struct network *net, const char *key);