]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon/network: enum endpoint_flag -> endpoint_flags_t
authorVladimír Čunát <vladimir.cunat@nic.cz>
Tue, 12 Mar 2019 09:28:53 +0000 (10:28 +0100)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Tue, 12 Mar 2019 11:42:03 +0000 (12:42 +0100)
The dual UDP+TCP is completely eliminated except for the externally
exposed "APIs" - lua net.listen() and command-line parameters.

daemon/bindings/net.c
daemon/main.c
daemon/network.c
daemon/network.h

index 0419aa9ad7b51866c55ca2ae4c27f500692a148f..c6a0ca70394d0eab7b6d1f10b691bcaf5d9c8a9d 100644 (file)
@@ -37,25 +37,29 @@ static int net_list_add(const char *key, void *val, void *ext)
                lua_setfield(L, -2, "ip");
 
                lua_newtable(L);  // "transport" table
-               if (ep->flags & NET_UDP) {
+               switch (ep->flags.sock_type) {
+               case SOCK_DGRAM:
                        lua_pushliteral(L, "udp");
                        lua_setfield(L, -2, "protocol");
                        lua_pushinteger(L, ep->port);
                        lua_setfield(L, -2, "port");
                        lua_pushliteral(L, "none");
                        lua_setfield(L, -2, "security");
-               } else if (ep->flags & NET_TCP) {
+                       break;
+               case SOCK_STREAM:
                        lua_pushliteral(L, "tcp");
                        lua_setfield(L, -2, "protocol");
                        lua_pushinteger(L, ep->port);
                        lua_setfield(L, -2, "port");
-                       if (ep->flags & NET_TLS) {
+                       if (ep->flags.tls) {
                                lua_pushliteral(L, "tls");
                        } else {
                                lua_pushliteral(L, "none");
                        }
                        lua_setfield(L, -2, "security");
-               } else {
+                       break;
+               default:
+                       assert(!EINVAL);
                        lua_pushliteral(L, "unknown");
                        lua_setfield(L, -2, "protocol");
                }
@@ -85,7 +89,7 @@ static int net_list(lua_State *L)
 }
 
 /** Listen on an address list represented by the top of lua stack. */
-static int net_listen_addrs(lua_State *L, int port, int flags)
+static int net_listen_addrs(lua_State *L, int port, bool tls)
 {
        /* Case: table with 'addr' field; only follow that field directly. */
        lua_getfield(L, -1, "addr");
@@ -99,7 +103,16 @@ static int net_listen_addrs(lua_State *L, int port, int flags)
        const char *str = lua_tostring(L, -1);
        if (str != NULL) {
                struct engine *engine = engine_luaget(L);
-               int ret = network_listen(&engine->net, str, port, flags);
+               endpoint_flags_t flags = { .tls = tls };
+               int ret = 0;
+               if (!tls) {
+                       flags.sock_type = SOCK_DGRAM;
+                       ret = network_listen(&engine->net, str, port, flags);
+               }
+               if (ret == 0) { /* common for TCP and TLS */
+                       flags.sock_type = SOCK_STREAM;
+                       ret = network_listen(&engine->net, str, port, flags);
+               }
                if (ret != 0) {
                        kr_log_info("[system] bind to '%s@%d' %s\n",
                                        str, port, kr_strerror(ret));
@@ -112,7 +125,7 @@ static int net_listen_addrs(lua_State *L, int port, int flags)
                lua_error_p(L, "bad type for address");
        lua_pushnil(L);
        while (lua_next(L, -2)) {
-               if (net_listen_addrs(L, port, flags) == 0)
+               if (net_listen_addrs(L, port, tls) == 0)
                        return 0;
                lua_pop(L, 1);
        }
@@ -150,11 +163,10 @@ static int net_listen(lua_State *L)
        if (n > 2 && lua_istable(L, 3)) {
                tls = table_get_flag(L, 3, "tls", tls);
        }
-       int flags = tls ? (NET_TCP|NET_TLS) : (NET_TCP|NET_UDP);
 
        /* Now focus on the first argument. */
        lua_pop(L, n - 1);
-       int res = net_listen_addrs(L, port, flags);
+       int res = net_listen_addrs(L, port, tls);
        lua_pushboolean(L, res);
        return res;
 }
@@ -168,10 +180,27 @@ static int net_close(lua_State *L)
                lua_error_p(L, "expected 'close(string addr, number port)'");
 
        /* Open resolution context cache */
-       struct engine *engine = engine_luaget(L);
-       int ret = network_close(&engine->net, lua_tostring(L, 1), lua_tointeger(L, 2),
-                               lua_tointeger(L, 3)/* 0 if not number-like */);
-       lua_pushboolean(L, ret == 0);
+       struct network *net = &engine_luaget(L)->net;
+       const char *addr = lua_tostring(L, 1);
+       const uint16_t port = lua_tointeger(L, 2);
+       endpoint_flags_t flags_all[] = {
+               { .sock_type = SOCK_DGRAM,  .tls = false },
+               { .sock_type = SOCK_STREAM, .tls = false },
+               { .sock_type = SOCK_STREAM, .tls = true },
+       };
+       bool success = false; /*< at least one deletion succeeded */
+       int ret = 0;
+       for (int i = 0; i < sizeof(flags_all) / sizeof(flags_all[0]); ++i) {
+               ret = network_close(net, addr, port, flags_all[i]);
+               if (ret == 0) {
+                       success = true;
+               } else if (ret != kr_error(ENOENT)) {
+                       break;
+               }
+               ret = 0;
+       }
+       /* true: no fatal error and at least one kr_ok() */
+       lua_pushboolean(L, ret == 0 && success);
        return 1;
 }
 
index f3d02d05a57a82abd2e865e17a6ded9e63d789e4..4f94c1e180c7873abe2fa39cb9bd4deafdd11e81 100644 (file)
@@ -628,13 +628,21 @@ static int bind_fds(struct network *net, fd_array_t *fd_set, bool tls) {
 }
 
 static int bind_sockets(struct network *net, addr_array_t *addr_set, bool tls) {
-       uint32_t flags = tls ? NET_TCP|NET_TLS : NET_UDP|NET_TCP;
+       endpoint_flags_t flags = { .tls = tls };
        for (size_t i = 0; i < addr_set->len; ++i) {
                uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT;
                char addr_str[INET6_ADDRSTRLEN + 1];
                int ret = kr_straddr_split(addr_set->at[i], addr_str, &port);
-               if (ret == 0)
+
+               if (ret == 0 && !tls) {
+                       flags.sock_type = SOCK_DGRAM;
+                       ret = network_listen(net, addr_str, port, flags);
+               }
+               if (ret == 0) { /* common for TCP and TLS */
+                       flags.sock_type = SOCK_STREAM;
                        ret = network_listen(net, addr_str, port, flags);
+               }
+
                if (ret != 0) {
                        kr_log_error("[system] bind to '%s' %s%s\n",
                                addr_set->at[i], tls ? "(TLS) " : "", kr_strerror(ret));
index 8b3c9721dfeae4db389f8217f4924a5218d1d742..d76a37af754b3ed1358065ce99c30c632768d8cb 100644 (file)
@@ -151,8 +151,8 @@ static int open_endpoint(struct network *net, struct endpoint *ep,
                return kr_error(EEXIST);
        }
 
-       if (ep->flags & NET_UDP) {
-               if (ep->flags & (NET_TCP | NET_TLS)) {
+       if (ep->flags.sock_type == SOCK_DGRAM) {
+               if (ep->flags.tls) {
                        assert(!EINVAL);
                        return kr_error(EINVAL);
                }
@@ -174,7 +174,7 @@ static int open_endpoint(struct network *net, struct endpoint *ep,
                }
        } /* else */
 
-       if (ep->flags & NET_TCP) {
+       if (ep->flags.sock_type == SOCK_STREAM) {
                uv_tcp_t *ep_handle = calloc(1, sizeof(uv_tcp_t));
                ep->handle = (uv_handle_t *)ep_handle;
                if (!ep->handle) {
@@ -188,7 +188,7 @@ static int open_endpoint(struct network *net, struct endpoint *ep,
                                return ret;
                        }
                }
-               if (ep->flags & NET_TLS) {
+               if (ep->flags.tls) {
                        return sa
                                ? tcp_bind_tls  (ep_handle, sa, net->tcp_backlog)
                                : tcp_bindfd_tls(ep_handle, fd, net->tcp_backlog);
@@ -205,13 +205,13 @@ static int open_endpoint(struct network *net, struct endpoint *ep,
 
 /** @internal Fetch endpoint array and offset of the address/port query. */
 static endpoint_array_t *network_get(struct network *net, const char *addr, uint16_t port,
-                                       uint16_t flags, size_t *index)
+                                       endpoint_flags_t flags, size_t *index)
 {
        endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
        if (ep_array) {
                for (size_t i = ep_array->len; i--;) {
                        struct endpoint *ep = ep_array->at[i];
-                       if (ep->port == port && ep->flags == flags) {
+                       if (ep->port == port && endpoint_flags_eq(ep->flags, flags)) {
                                *index = i;
                                return ep_array;
                        }
@@ -222,7 +222,7 @@ static endpoint_array_t *network_get(struct network *net, const char *addr, uint
 
 /** \note pass either sa != NULL xor fd != -1 */
 static int create_endpoint(struct network *net, const char *addr_str,
-                               uint16_t port, uint16_t flags,
+                               uint16_t port, endpoint_flags_t flags,
                                const struct sockaddr *sa, int fd)
 {
        /* Bind interfaces */
@@ -230,8 +230,8 @@ static int create_endpoint(struct network *net, const char *addr_str,
        if (!ep) {
                return kr_error(ENOMEM);
        }
-       ep->flags = flags;
        ep->port = port;
+       ep->flags = flags;
        int ret = open_endpoint(net, ep, sa, fd);
        if (ret == 0) {
                ret = insert_endpoint(net, addr_str, ep);
@@ -245,25 +245,17 @@ static int create_endpoint(struct network *net, const char *addr_str,
 int network_listen_fd(struct network *net, int fd, bool use_tls)
 {
        /* Extract fd's socket type. */
-       int sock_type;
-       socklen_t len = sizeof(sock_type);
-       int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &len);
+       endpoint_flags_t flags = { .tls = use_tls };
+       socklen_t len = sizeof(flags.sock_type);
+       int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len);
        if (ret != 0) {
                return kr_error(errno);
        }
-       uint16_t flags;
-       if (sock_type == SOCK_DGRAM) {
-               flags = NET_UDP;
-               if (use_tls) {
-                       assert(!EINVAL);
-                       return kr_error(EINVAL);
-               }
-       } else if (sock_type == SOCK_STREAM) {
-               flags = NET_TCP;
-               if (use_tls) {
-                       flags |= NET_TLS;
-               }
-       } else {
+       if (flags.sock_type == SOCK_DGRAM && use_tls) {
+               assert(!EINVAL);
+               return kr_error(EINVAL);
+       }
+       if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) {
                return kr_error(EBADF);
        }
 
@@ -291,16 +283,16 @@ int network_listen_fd(struct network *net, int fd, bool use_tls)
        return create_endpoint(net, addr_str, port, flags, NULL, fd);
 }
 
-int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags)
+int network_listen(struct network *net, const char *addr, uint16_t port,
+                  endpoint_flags_t flags)
 {
        if (net == NULL || addr == 0 || port == 0) {
+               assert(!EINVAL);
                return kr_error(EINVAL);
        }
-
-       /* Already listening */
        size_t index = 0;
        if (network_get(net, addr, port, flags, &index)) {
-               return kr_ok();
+               return kr_ok(); /* Already listening */
        }
 
        /* Parse address. */
@@ -314,21 +306,11 @@ int network_listen(struct network *net, const char *addr, uint16_t port, uint16_
        if (ret != 0) {
                return ret;
        }
-
-       if ((flags & NET_UDP) && (flags & NET_TCP)) {
-               /* We accept  ^^ this shorthand at this API layer. */
-               ret = create_endpoint(net, addr, port, flags & ~NET_TCP, &sa.ip, -1);
-               if (ret == 0) {
-                       ret = create_endpoint(net, addr, port, flags & ~NET_UDP, &sa.ip, -1);
-               }
-       } else {
-               ret = create_endpoint(net, addr, port, flags, &sa.ip, -1);
-       }
-
-       return ret;
+       return create_endpoint(net, addr, port, flags, &sa.ip, -1);
 }
 
-int network_close(struct network *net, const char *addr, uint16_t port, uint16_t flags)
+int network_close(struct network *net, const char *addr, uint16_t port,
+                 endpoint_flags_t flags)
 {
        endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
        if (!ep_array) {
@@ -339,7 +321,7 @@ int network_close(struct network *net, const char *addr, uint16_t port, uint16_t
        bool matched = false;
        while (i < ep_array->len) {
                struct endpoint *ep = ep_array->at[i];
-               if (!flags || flags == ep->flags) {
+               if (endpoint_flags_eq(flags, ep->flags)) {
                        close_endpoint(ep, false);
                        array_del(*ep_array, i);
                        matched = true;
index 36b88855ee98f53b558b4582a4074e8dbffc8972..7167ab44f812788886f93403f42c2fc38b44452b 100644 (file)
 
 struct engine;
 
-enum endpoint_flag {
-       NET_DOWN = 0,
-       NET_UDP  = 1 << 0,
-       NET_TCP  = 1 << 1,
-       NET_TLS  = 1 << 2, /**< only used together with NET_TCP */
-};
+/** Ways to listen for DNS on a port. */
+typedef struct {
+       int sock_type;  /**< SOCK_DGRAM or SOCK_STREAM */
+       bool tls;       /**< only used together with .tcp */
+} endpoint_flags_t;
+
+static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2)
+{
+       /* memcmp() would typically work, but there's no guarantee. */
+       return f1.sock_type == f2.sock_type && f1.tls == f2.tls;
+}
 
 /** Wrapper for a single socket to listen on. */
 struct endpoint {
        uv_handle_t *handle; /** uv_udp_t or uv_tcp_t */
        uint16_t port;
-       uint16_t flags; /**< see enum endpoint_flag; (_UDP | _TCP) *not* allowed */
+       endpoint_flags_t flags;
 };
 
 /** @cond internal Array of endpoints */
@@ -68,17 +73,20 @@ struct network {
 void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog);
 void network_deinit(struct network *net);
 
-/** Start listenting on addr#port.
- * \param flags see enum endpoint_flag; (NET_UDP | NET_TCP) is allowed.
- * \note if we did listen already, nothing is done and kr_ok() is returned. */
-int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags);
+/** Start listenting on addr#port with flags.
+ * \note if we did listen on that combination already,
+ *       nothing is done and kr_ok() is returned.
+ * \note there's no short-hand to listen both on UDP and TCP. */
+int network_listen(struct network *net, const char *addr, uint16_t port,
+                  endpoint_flags_t flags);
 
 /** Start listenting on an open file-descriptor. */
 int network_listen_fd(struct network *net, int fd, bool use_tls);
 
-/** Stop listening on all addr#port with equal flags; flags == 0 means all of them.
+/** Stop listening on all addr#port with equal flags.
  * \return kr_error(ENOENT) if nothing matched. */
-int network_close(struct network *net, const char *addr, uint16_t port, uint16_t flags);
+int network_close(struct network *net, const char *addr, uint16_t port,
+                 endpoint_flags_t flags);
 
 int network_set_tls_cert(struct network *net, const char *cert);
 int network_set_tls_key(struct network *net, const char *key);