lua_setfield(L, -2, "ip");
lua_newtable(L); // "transport" table
- if (ep->flags & NET_UDP) {
+ switch (ep->flags.sock_type) {
+ case SOCK_DGRAM:
lua_pushliteral(L, "udp");
lua_setfield(L, -2, "protocol");
lua_pushinteger(L, ep->port);
lua_setfield(L, -2, "port");
lua_pushliteral(L, "none");
lua_setfield(L, -2, "security");
- } else if (ep->flags & NET_TCP) {
+ break;
+ case SOCK_STREAM:
lua_pushliteral(L, "tcp");
lua_setfield(L, -2, "protocol");
lua_pushinteger(L, ep->port);
lua_setfield(L, -2, "port");
- if (ep->flags & NET_TLS) {
+ if (ep->flags.tls) {
lua_pushliteral(L, "tls");
} else {
lua_pushliteral(L, "none");
}
lua_setfield(L, -2, "security");
- } else {
+ break;
+ default:
+ assert(!EINVAL);
lua_pushliteral(L, "unknown");
lua_setfield(L, -2, "protocol");
}
}
/** Listen on an address list represented by the top of lua stack. */
-static int net_listen_addrs(lua_State *L, int port, int flags)
+static int net_listen_addrs(lua_State *L, int port, bool tls)
{
/* Case: table with 'addr' field; only follow that field directly. */
lua_getfield(L, -1, "addr");
const char *str = lua_tostring(L, -1);
if (str != NULL) {
struct engine *engine = engine_luaget(L);
- int ret = network_listen(&engine->net, str, port, flags);
+ endpoint_flags_t flags = { .tls = tls };
+ int ret = 0;
+ if (!tls) {
+ flags.sock_type = SOCK_DGRAM;
+ ret = network_listen(&engine->net, str, port, flags);
+ }
+ if (ret == 0) { /* common for TCP and TLS */
+ flags.sock_type = SOCK_STREAM;
+ ret = network_listen(&engine->net, str, port, flags);
+ }
if (ret != 0) {
kr_log_info("[system] bind to '%s@%d' %s\n",
str, port, kr_strerror(ret));
lua_error_p(L, "bad type for address");
lua_pushnil(L);
while (lua_next(L, -2)) {
- if (net_listen_addrs(L, port, flags) == 0)
+ if (net_listen_addrs(L, port, tls) == 0)
return 0;
lua_pop(L, 1);
}
if (n > 2 && lua_istable(L, 3)) {
tls = table_get_flag(L, 3, "tls", tls);
}
- int flags = tls ? (NET_TCP|NET_TLS) : (NET_TCP|NET_UDP);
/* Now focus on the first argument. */
lua_pop(L, n - 1);
- int res = net_listen_addrs(L, port, flags);
+ int res = net_listen_addrs(L, port, tls);
lua_pushboolean(L, res);
return res;
}
lua_error_p(L, "expected 'close(string addr, number port)'");
/* Open resolution context cache */
- struct engine *engine = engine_luaget(L);
- int ret = network_close(&engine->net, lua_tostring(L, 1), lua_tointeger(L, 2),
- lua_tointeger(L, 3)/* 0 if not number-like */);
- lua_pushboolean(L, ret == 0);
+ struct network *net = &engine_luaget(L)->net;
+ const char *addr = lua_tostring(L, 1);
+ const uint16_t port = lua_tointeger(L, 2);
+ endpoint_flags_t flags_all[] = {
+ { .sock_type = SOCK_DGRAM, .tls = false },
+ { .sock_type = SOCK_STREAM, .tls = false },
+ { .sock_type = SOCK_STREAM, .tls = true },
+ };
+ bool success = false; /*< at least one deletion succeeded */
+ int ret = 0;
+ for (int i = 0; i < sizeof(flags_all) / sizeof(flags_all[0]); ++i) {
+ ret = network_close(net, addr, port, flags_all[i]);
+ if (ret == 0) {
+ success = true;
+ } else if (ret != kr_error(ENOENT)) {
+ break;
+ }
+ ret = 0;
+ }
+ /* true: no fatal error and at least one kr_ok() */
+ lua_pushboolean(L, ret == 0 && success);
return 1;
}
}
static int bind_sockets(struct network *net, addr_array_t *addr_set, bool tls) {
- uint32_t flags = tls ? NET_TCP|NET_TLS : NET_UDP|NET_TCP;
+ endpoint_flags_t flags = { .tls = tls };
for (size_t i = 0; i < addr_set->len; ++i) {
uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT;
char addr_str[INET6_ADDRSTRLEN + 1];
int ret = kr_straddr_split(addr_set->at[i], addr_str, &port);
- if (ret == 0)
+
+ if (ret == 0 && !tls) {
+ flags.sock_type = SOCK_DGRAM;
+ ret = network_listen(net, addr_str, port, flags);
+ }
+ if (ret == 0) { /* common for TCP and TLS */
+ flags.sock_type = SOCK_STREAM;
ret = network_listen(net, addr_str, port, flags);
+ }
+
if (ret != 0) {
kr_log_error("[system] bind to '%s' %s%s\n",
addr_set->at[i], tls ? "(TLS) " : "", kr_strerror(ret));
return kr_error(EEXIST);
}
- if (ep->flags & NET_UDP) {
- if (ep->flags & (NET_TCP | NET_TLS)) {
+ if (ep->flags.sock_type == SOCK_DGRAM) {
+ if (ep->flags.tls) {
assert(!EINVAL);
return kr_error(EINVAL);
}
}
} /* else */
- if (ep->flags & NET_TCP) {
+ if (ep->flags.sock_type == SOCK_STREAM) {
uv_tcp_t *ep_handle = calloc(1, sizeof(uv_tcp_t));
ep->handle = (uv_handle_t *)ep_handle;
if (!ep->handle) {
return ret;
}
}
- if (ep->flags & NET_TLS) {
+ if (ep->flags.tls) {
return sa
? tcp_bind_tls (ep_handle, sa, net->tcp_backlog)
: tcp_bindfd_tls(ep_handle, fd, net->tcp_backlog);
/** @internal Fetch endpoint array and offset of the address/port query. */
static endpoint_array_t *network_get(struct network *net, const char *addr, uint16_t port,
- uint16_t flags, size_t *index)
+ endpoint_flags_t flags, size_t *index)
{
endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
if (ep_array) {
for (size_t i = ep_array->len; i--;) {
struct endpoint *ep = ep_array->at[i];
- if (ep->port == port && ep->flags == flags) {
+ if (ep->port == port && endpoint_flags_eq(ep->flags, flags)) {
*index = i;
return ep_array;
}
/** \note pass either sa != NULL xor fd != -1 */
static int create_endpoint(struct network *net, const char *addr_str,
- uint16_t port, uint16_t flags,
+ uint16_t port, endpoint_flags_t flags,
const struct sockaddr *sa, int fd)
{
/* Bind interfaces */
if (!ep) {
return kr_error(ENOMEM);
}
- ep->flags = flags;
ep->port = port;
+ ep->flags = flags;
int ret = open_endpoint(net, ep, sa, fd);
if (ret == 0) {
ret = insert_endpoint(net, addr_str, ep);
int network_listen_fd(struct network *net, int fd, bool use_tls)
{
/* Extract fd's socket type. */
- int sock_type;
- socklen_t len = sizeof(sock_type);
- int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &len);
+ endpoint_flags_t flags = { .tls = use_tls };
+ socklen_t len = sizeof(flags.sock_type);
+ int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len);
if (ret != 0) {
return kr_error(errno);
}
- uint16_t flags;
- if (sock_type == SOCK_DGRAM) {
- flags = NET_UDP;
- if (use_tls) {
- assert(!EINVAL);
- return kr_error(EINVAL);
- }
- } else if (sock_type == SOCK_STREAM) {
- flags = NET_TCP;
- if (use_tls) {
- flags |= NET_TLS;
- }
- } else {
+ if (flags.sock_type == SOCK_DGRAM && use_tls) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) {
return kr_error(EBADF);
}
return create_endpoint(net, addr_str, port, flags, NULL, fd);
}
-int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags)
+int network_listen(struct network *net, const char *addr, uint16_t port,
+ endpoint_flags_t flags)
{
if (net == NULL || addr == 0 || port == 0) {
+ assert(!EINVAL);
return kr_error(EINVAL);
}
-
- /* Already listening */
size_t index = 0;
if (network_get(net, addr, port, flags, &index)) {
- return kr_ok();
+ return kr_ok(); /* Already listening */
}
/* Parse address. */
if (ret != 0) {
return ret;
}
-
- if ((flags & NET_UDP) && (flags & NET_TCP)) {
- /* We accept ^^ this shorthand at this API layer. */
- ret = create_endpoint(net, addr, port, flags & ~NET_TCP, &sa.ip, -1);
- if (ret == 0) {
- ret = create_endpoint(net, addr, port, flags & ~NET_UDP, &sa.ip, -1);
- }
- } else {
- ret = create_endpoint(net, addr, port, flags, &sa.ip, -1);
- }
-
- return ret;
+ return create_endpoint(net, addr, port, flags, &sa.ip, -1);
}
-int network_close(struct network *net, const char *addr, uint16_t port, uint16_t flags)
+int network_close(struct network *net, const char *addr, uint16_t port,
+ endpoint_flags_t flags)
{
endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
if (!ep_array) {
bool matched = false;
while (i < ep_array->len) {
struct endpoint *ep = ep_array->at[i];
- if (!flags || flags == ep->flags) {
+ if (endpoint_flags_eq(flags, ep->flags)) {
close_endpoint(ep, false);
array_del(*ep_array, i);
matched = true;
struct engine;
-enum endpoint_flag {
- NET_DOWN = 0,
- NET_UDP = 1 << 0,
- NET_TCP = 1 << 1,
- NET_TLS = 1 << 2, /**< only used together with NET_TCP */
-};
+/** Ways to listen for DNS on a port. */
+typedef struct {
+ int sock_type; /**< SOCK_DGRAM or SOCK_STREAM */
+ bool tls; /**< only used together with .tcp */
+} endpoint_flags_t;
+
+static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2)
+{
+ /* memcmp() would typically work, but there's no guarantee. */
+ return f1.sock_type == f2.sock_type && f1.tls == f2.tls;
+}
/** Wrapper for a single socket to listen on. */
struct endpoint {
uv_handle_t *handle; /** uv_udp_t or uv_tcp_t */
uint16_t port;
- uint16_t flags; /**< see enum endpoint_flag; (_UDP | _TCP) *not* allowed */
+ endpoint_flags_t flags;
};
/** @cond internal Array of endpoints */
void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog);
void network_deinit(struct network *net);
-/** Start listenting on addr#port.
- * \param flags see enum endpoint_flag; (NET_UDP | NET_TCP) is allowed.
- * \note if we did listen already, nothing is done and kr_ok() is returned. */
-int network_listen(struct network *net, const char *addr, uint16_t port, uint16_t flags);
+/** Start listenting on addr#port with flags.
+ * \note if we did listen on that combination already,
+ * nothing is done and kr_ok() is returned.
+ * \note there's no short-hand to listen both on UDP and TCP. */
+int network_listen(struct network *net, const char *addr, uint16_t port,
+ endpoint_flags_t flags);
/** Start listenting on an open file-descriptor. */
int network_listen_fd(struct network *net, int fd, bool use_tls);
-/** Stop listening on all addr#port with equal flags; flags == 0 means all of them.
+/** Stop listening on all addr#port with equal flags.
* \return kr_error(ENOENT) if nothing matched. */
-int network_close(struct network *net, const char *addr, uint16_t port, uint16_t flags);
+int network_close(struct network *net, const char *addr, uint16_t port,
+ endpoint_flags_t flags);
int network_set_tls_cert(struct network *net, const char *cert);
int network_set_tls_key(struct network *net, const char *key);