From: Vladimír Čunát Date: Tue, 12 Mar 2019 09:28:53 +0000 (+0100) Subject: daemon/network: enum endpoint_flag -> endpoint_flags_t X-Git-Tag: v4.0.0~21^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9fb5866a7fe443cb5007de31edc86244ce152ff4;p=thirdparty%2Fknot-resolver.git daemon/network: enum endpoint_flag -> endpoint_flags_t The dual UDP+TCP is completely eliminated except for the externally exposed "APIs" - lua net.listen() and command-line parameters. --- diff --git a/daemon/bindings/net.c b/daemon/bindings/net.c index 0419aa9ad..c6a0ca703 100644 --- a/daemon/bindings/net.c +++ b/daemon/bindings/net.c @@ -37,25 +37,29 @@ static int net_list_add(const char *key, void *val, void *ext) lua_setfield(L, -2, "ip"); lua_newtable(L); // "transport" table - if (ep->flags & NET_UDP) { + switch (ep->flags.sock_type) { + case SOCK_DGRAM: lua_pushliteral(L, "udp"); lua_setfield(L, -2, "protocol"); lua_pushinteger(L, ep->port); lua_setfield(L, -2, "port"); lua_pushliteral(L, "none"); lua_setfield(L, -2, "security"); - } else if (ep->flags & NET_TCP) { + break; + case SOCK_STREAM: lua_pushliteral(L, "tcp"); lua_setfield(L, -2, "protocol"); lua_pushinteger(L, ep->port); lua_setfield(L, -2, "port"); - if (ep->flags & NET_TLS) { + if (ep->flags.tls) { lua_pushliteral(L, "tls"); } else { lua_pushliteral(L, "none"); } lua_setfield(L, -2, "security"); - } else { + break; + default: + assert(!EINVAL); lua_pushliteral(L, "unknown"); lua_setfield(L, -2, "protocol"); } @@ -85,7 +89,7 @@ static int net_list(lua_State *L) } /** Listen on an address list represented by the top of lua stack. */ -static int net_listen_addrs(lua_State *L, int port, int flags) +static int net_listen_addrs(lua_State *L, int port, bool tls) { /* Case: table with 'addr' field; only follow that field directly. */ lua_getfield(L, -1, "addr"); @@ -99,7 +103,16 @@ static int net_listen_addrs(lua_State *L, int port, int flags) const char *str = lua_tostring(L, -1); if (str != NULL) { struct engine *engine = engine_luaget(L); - int ret = network_listen(&engine->net, str, port, flags); + endpoint_flags_t flags = { .tls = tls }; + int ret = 0; + if (!tls) { + flags.sock_type = SOCK_DGRAM; + ret = network_listen(&engine->net, str, port, flags); + } + if (ret == 0) { /* common for TCP and TLS */ + flags.sock_type = SOCK_STREAM; + ret = network_listen(&engine->net, str, port, flags); + } if (ret != 0) { kr_log_info("[system] bind to '%s@%d' %s\n", str, port, kr_strerror(ret)); @@ -112,7 +125,7 @@ static int net_listen_addrs(lua_State *L, int port, int flags) lua_error_p(L, "bad type for address"); lua_pushnil(L); while (lua_next(L, -2)) { - if (net_listen_addrs(L, port, flags) == 0) + if (net_listen_addrs(L, port, tls) == 0) return 0; lua_pop(L, 1); } @@ -150,11 +163,10 @@ static int net_listen(lua_State *L) if (n > 2 && lua_istable(L, 3)) { tls = table_get_flag(L, 3, "tls", tls); } - int flags = tls ? (NET_TCP|NET_TLS) : (NET_TCP|NET_UDP); /* Now focus on the first argument. */ lua_pop(L, n - 1); - int res = net_listen_addrs(L, port, flags); + int res = net_listen_addrs(L, port, tls); lua_pushboolean(L, res); return res; } @@ -168,10 +180,27 @@ static int net_close(lua_State *L) lua_error_p(L, "expected 'close(string addr, number port)'"); /* Open resolution context cache */ - struct engine *engine = engine_luaget(L); - int ret = network_close(&engine->net, lua_tostring(L, 1), lua_tointeger(L, 2), - lua_tointeger(L, 3)/* 0 if not number-like */); - lua_pushboolean(L, ret == 0); + struct network *net = &engine_luaget(L)->net; + const char *addr = lua_tostring(L, 1); + const uint16_t port = lua_tointeger(L, 2); + endpoint_flags_t flags_all[] = { + { .sock_type = SOCK_DGRAM, .tls = false }, + { .sock_type = SOCK_STREAM, .tls = false }, + { .sock_type = SOCK_STREAM, .tls = true }, + }; + bool success = false; /*< at least one deletion succeeded */ + int ret = 0; + for (int i = 0; i < sizeof(flags_all) / sizeof(flags_all[0]); ++i) { + ret = network_close(net, addr, port, flags_all[i]); + if (ret == 0) { + success = true; + } else if (ret != kr_error(ENOENT)) { + break; + } + ret = 0; + } + /* true: no fatal error and at least one kr_ok() */ + lua_pushboolean(L, ret == 0 && success); return 1; } diff --git a/daemon/main.c b/daemon/main.c index f3d02d05a..4f94c1e18 100644 --- a/daemon/main.c +++ b/daemon/main.c @@ -628,13 +628,21 @@ static int bind_fds(struct network *net, fd_array_t *fd_set, bool tls) { } static int bind_sockets(struct network *net, addr_array_t *addr_set, bool tls) { - uint32_t flags = tls ? NET_TCP|NET_TLS : NET_UDP|NET_TCP; + endpoint_flags_t flags = { .tls = tls }; for (size_t i = 0; i < addr_set->len; ++i) { uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT; char addr_str[INET6_ADDRSTRLEN + 1]; int ret = kr_straddr_split(addr_set->at[i], addr_str, &port); - if (ret == 0) + + if (ret == 0 && !tls) { + flags.sock_type = SOCK_DGRAM; + ret = network_listen(net, addr_str, port, flags); + } + if (ret == 0) { /* common for TCP and TLS */ + flags.sock_type = SOCK_STREAM; ret = network_listen(net, addr_str, port, flags); + } + if (ret != 0) { kr_log_error("[system] bind to '%s' %s%s\n", addr_set->at[i], tls ? "(TLS) " : "", kr_strerror(ret)); diff --git a/daemon/network.c b/daemon/network.c index 8b3c9721d..d76a37af7 100644 --- a/daemon/network.c +++ b/daemon/network.c @@ -151,8 +151,8 @@ static int open_endpoint(struct network *net, struct endpoint *ep, return kr_error(EEXIST); } - if (ep->flags & NET_UDP) { - if (ep->flags & (NET_TCP | NET_TLS)) { + if (ep->flags.sock_type == SOCK_DGRAM) { + if (ep->flags.tls) { assert(!EINVAL); return kr_error(EINVAL); } @@ -174,7 +174,7 @@ static int open_endpoint(struct network *net, struct endpoint *ep, } } /* else */ - if (ep->flags & NET_TCP) { + if (ep->flags.sock_type == SOCK_STREAM) { uv_tcp_t *ep_handle = calloc(1, sizeof(uv_tcp_t)); ep->handle = (uv_handle_t *)ep_handle; if (!ep->handle) { @@ -188,7 +188,7 @@ static int open_endpoint(struct network *net, struct endpoint *ep, return ret; } } - if (ep->flags & NET_TLS) { + if (ep->flags.tls) { return sa ? tcp_bind_tls (ep_handle, sa, net->tcp_backlog) : tcp_bindfd_tls(ep_handle, fd, net->tcp_backlog); @@ -205,13 +205,13 @@ static int open_endpoint(struct network *net, struct endpoint *ep, /** @internal Fetch endpoint array and offset of the address/port query. */ static endpoint_array_t *network_get(struct network *net, const char *addr, uint16_t port, - uint16_t flags, size_t *index) + endpoint_flags_t flags, size_t *index) { endpoint_array_t *ep_array = map_get(&net->endpoints, addr); if (ep_array) { for (size_t i = ep_array->len; i--;) { struct endpoint *ep = ep_array->at[i]; - if (ep->port == port && ep->flags == flags) { + if (ep->port == port && endpoint_flags_eq(ep->flags, flags)) { *index = i; return ep_array; } @@ -222,7 +222,7 @@ static endpoint_array_t *network_get(struct network *net, const char *addr, uint /** \note pass either sa != NULL xor fd != -1 */ static int create_endpoint(struct network *net, const char *addr_str, - uint16_t port, uint16_t flags, + uint16_t port, endpoint_flags_t flags, const struct sockaddr *sa, int fd) { /* Bind interfaces */ @@ -230,8 +230,8 @@ static int create_endpoint(struct network *net, const char *addr_str, if (!ep) { return kr_error(ENOMEM); } - ep->flags = flags; ep->port = port; + ep->flags = flags; int ret = open_endpoint(net, ep, sa, fd); if (ret == 0) { ret = insert_endpoint(net, addr_str, ep); @@ -245,25 +245,17 @@ static int create_endpoint(struct network *net, const char *addr_str, int network_listen_fd(struct network *net, int fd, bool use_tls) { /* Extract fd's socket type. */ - int sock_type; - socklen_t len = sizeof(sock_type); - int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &len); + endpoint_flags_t flags = { .tls = use_tls }; + socklen_t len = sizeof(flags.sock_type); + int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len); if (ret != 0) { return kr_error(errno); } - uint16_t flags; - if (sock_type == SOCK_DGRAM) { - flags = NET_UDP; - if (use_tls) { - assert(!EINVAL); - return kr_error(EINVAL); - } - } else if (sock_type == SOCK_STREAM) { - flags = NET_TCP; - if (use_tls) { - flags |= NET_TLS; - } - } else { + if (flags.sock_type == SOCK_DGRAM && use_tls) { + assert(!EINVAL); + return kr_error(EINVAL); + } + if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) { return kr_error(EBADF); } @@ -291,16 +283,16 @@ int network_listen_fd(struct network *net, int fd, bool use_tls) return create_endpoint(net, addr_str, port, flags, NULL, fd); } -int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags) +int network_listen(struct network *net, const char *addr, uint16_t port, + endpoint_flags_t flags) { if (net == NULL || addr == 0 || port == 0) { + assert(!EINVAL); return kr_error(EINVAL); } - - /* Already listening */ size_t index = 0; if (network_get(net, addr, port, flags, &index)) { - return kr_ok(); + return kr_ok(); /* Already listening */ } /* Parse address. */ @@ -314,21 +306,11 @@ int network_listen(struct network *net, const char *addr, uint16_t port, uint16_ if (ret != 0) { return ret; } - - if ((flags & NET_UDP) && (flags & NET_TCP)) { - /* We accept ^^ this shorthand at this API layer. */ - ret = create_endpoint(net, addr, port, flags & ~NET_TCP, &sa.ip, -1); - if (ret == 0) { - ret = create_endpoint(net, addr, port, flags & ~NET_UDP, &sa.ip, -1); - } - } else { - ret = create_endpoint(net, addr, port, flags, &sa.ip, -1); - } - - return ret; + return create_endpoint(net, addr, port, flags, &sa.ip, -1); } -int network_close(struct network *net, const char *addr, uint16_t port, uint16_t flags) +int network_close(struct network *net, const char *addr, uint16_t port, + endpoint_flags_t flags) { endpoint_array_t *ep_array = map_get(&net->endpoints, addr); if (!ep_array) { @@ -339,7 +321,7 @@ int network_close(struct network *net, const char *addr, uint16_t port, uint16_t bool matched = false; while (i < ep_array->len) { struct endpoint *ep = ep_array->at[i]; - if (!flags || flags == ep->flags) { + if (endpoint_flags_eq(flags, ep->flags)) { close_endpoint(ep, false); array_del(*ep_array, i); matched = true; diff --git a/daemon/network.h b/daemon/network.h index 36b88855e..7167ab44f 100644 --- a/daemon/network.h +++ b/daemon/network.h @@ -27,18 +27,23 @@ struct engine; -enum endpoint_flag { - NET_DOWN = 0, - NET_UDP = 1 << 0, - NET_TCP = 1 << 1, - NET_TLS = 1 << 2, /**< only used together with NET_TCP */ -}; +/** Ways to listen for DNS on a port. */ +typedef struct { + int sock_type; /**< SOCK_DGRAM or SOCK_STREAM */ + bool tls; /**< only used together with .tcp */ +} endpoint_flags_t; + +static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2) +{ + /* memcmp() would typically work, but there's no guarantee. */ + return f1.sock_type == f2.sock_type && f1.tls == f2.tls; +} /** Wrapper for a single socket to listen on. */ struct endpoint { uv_handle_t *handle; /** uv_udp_t or uv_tcp_t */ uint16_t port; - uint16_t flags; /**< see enum endpoint_flag; (_UDP | _TCP) *not* allowed */ + endpoint_flags_t flags; }; /** @cond internal Array of endpoints */ @@ -68,17 +73,20 @@ struct network { void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog); void network_deinit(struct network *net); -/** Start listenting on addr#port. - * \param flags see enum endpoint_flag; (NET_UDP | NET_TCP) is allowed. - * \note if we did listen already, nothing is done and kr_ok() is returned. */ -int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags); +/** Start listenting on addr#port with flags. + * \note if we did listen on that combination already, + * nothing is done and kr_ok() is returned. + * \note there's no short-hand to listen both on UDP and TCP. */ +int network_listen(struct network *net, const char *addr, uint16_t port, + endpoint_flags_t flags); /** Start listenting on an open file-descriptor. */ int network_listen_fd(struct network *net, int fd, bool use_tls); -/** Stop listening on all addr#port with equal flags; flags == 0 means all of them. +/** Stop listening on all addr#port with equal flags. * \return kr_error(ENOENT) if nothing matched. */ -int network_close(struct network *net, const char *addr, uint16_t port, uint16_t flags); +int network_close(struct network *net, const char *addr, uint16_t port, + endpoint_flags_t flags); int network_set_tls_cert(struct network *net, const char *cert); int network_set_tls_key(struct network *net, const char *key);