From: Vladimír Čunát Date: Sat, 13 Apr 2019 08:25:20 +0000 (+0200) Subject: daemon: support for more endpoint kinds X-Git-Tag: v4.0.0~4^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2da7891322a99e81170055b45d1ac39968ad80d7;p=thirdparty%2Fknot-resolver.git daemon: support for more endpoint kinds --- diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a2cc0f946..511b81289 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -87,7 +87,7 @@ kres-gen: - docker script: - meson build_ci_lib --prefix=$PREFIX - - ninja -C build_ci_lib lib/libkres.so.${LIBKRES_ABI} + - ninja -C build_ci_lib daemon/kresd - ninja -C build_ci_lib kres-gen - git diff --quiet || (git diff; exit 1) # }}} diff --git a/daemon/bindings/net.c b/daemon/bindings/net.c index 30ba7971d..fc7f23a34 100644 --- a/daemon/bindings/net.c +++ b/daemon/bindings/net.c @@ -66,7 +66,11 @@ static int net_list_add(const char *key, void *val, void *ext) lua_setfield(L, -2, "transport"); lua_newtable(L); // "application" table - lua_pushliteral(L, "dns"); + if (ep->flags.kind) { + lua_pushstring(L, ep->flags.kind); + } else { + lua_pushliteral(L, "dns"); + } lua_setfield(L, -2, "protocol"); lua_setfield(L, -2, "application"); @@ -89,8 +93,9 @@ static int net_list(lua_State *L) } /** Listen on an address list represented by the top of lua stack. + * \note kind ownership is not transferred * \return success */ -static bool net_listen_addrs(lua_State *L, int port, bool tls) +static bool net_listen_addrs(lua_State *L, int port, bool tls, const char *kind) { /* Case: table with 'addr' field; only follow that field directly. */ lua_getfield(L, -1, "addr"); @@ -104,19 +109,25 @@ static bool net_listen_addrs(lua_State *L, int port, bool tls) const char *str = lua_tostring(L, -1); if (str != NULL) { struct engine *engine = engine_luaget(L); - endpoint_flags_t flags = { .tls = tls }; int ret = 0; - if (!tls) { + endpoint_flags_t flags = { .tls = tls }; + if (!kind && !flags.tls) { /* normal UDP */ flags.sock_type = SOCK_DGRAM; ret = network_listen(&engine->net, str, port, flags); } - if (ret == 0) { /* common for TCP and TLS */ + if (!kind && ret == 0) { /* common for normal TCP and TLS */ flags.sock_type = SOCK_STREAM; ret = network_listen(&engine->net, str, port, flags); } + if (kind) { + flags.kind = strdup(kind); + flags.sock_type = SOCK_STREAM; /* TODO: allow to override this? */ + ret = network_listen(&engine->net, str, port, flags); + } if (ret != 0) { - kr_log_info("[system] bind to '%s@%d' %s\n", - str, port, kr_strerror(ret)); + const char *stype = flags.sock_type == SOCK_DGRAM ? "UDP" : "TCP"; + kr_log_error("[system] bind to '%s@%d' (%s): %s\n", + str, port, stype, kr_strerror(ret)); } return ret == 0; } @@ -126,7 +137,7 @@ static bool net_listen_addrs(lua_State *L, int port, bool tls) lua_error_p(L, "bad type for address"); lua_pushnil(L); while (lua_next(L, -2)) { - if (!net_listen_addrs(L, port, tls)) + if (!net_listen_addrs(L, port, tls, kind)) return false; lua_pop(L, 1); } @@ -156,18 +167,47 @@ static int net_listen(lua_State *L) } int port = KR_DNS_PORT; - if (n > 1 && lua_isnumber(L, 2)) { - port = lua_tointeger(L, 2); + if (n > 1) { + if (lua_isnumber(L, 2)) { + port = lua_tointeger(L, 2); + } else + if (!lua_isnil(L, 2)) { + lua_error_p(L, "wrong type of second parameter (port number)"); + } } bool tls = (port == KR_DNS_TLS_PORT); - if (n > 2 && lua_istable(L, 3)) { + const char *kind = NULL; + if (n > 2) { + if (!lua_istable(L, 3)) + lua_error_p(L, "wrong type of third parameter (table expected)"); tls = table_get_flag(L, 3, "tls", tls); + + lua_getfield(L, 3, "kind"); + const char *k = lua_tostring(L, -1); + if (k && strcasecmp(k, "dns") == 0) { + tls = false; + } else + if (k && strcasecmp(k, "tls") == 0) { + tls = true; + } else + if (k) { + kind = k; + } + } + + /* Memory management of `kind` string is difficult due to longjmp etc. + * Pop will unreference the lua value, so we store it on C stack instead (!) */ + const int kind_alen = kind ? strlen(kind) + 1 : 1 /* 0 length isn't C standard */; + char kind_buf[kind_alen]; + if (kind) { + memcpy(kind_buf, kind, kind_alen); + kind = kind_buf; } /* Now focus on the first argument. */ - lua_pop(L, n - 1); - const bool res = net_listen_addrs(L, port, tls); + lua_settop(L, 1); + const bool res = net_listen_addrs(L, port, tls, kind); lua_pushboolean(L, res); return 1; } @@ -180,6 +220,8 @@ static int net_close(lua_State *L) if (n < 2) lua_error_p(L, "expected 'close(string addr, number port)'"); + /* FIXME: support different kind values */ + /* Open resolution context cache */ struct network *net = &engine_luaget(L)->net; const char *addr = lua_tostring(L, 1); @@ -922,6 +964,47 @@ static int net_bpf_clear(lua_State *L) lua_error_p(L, "BPF is not supported on this operating system"); } +static int net_register_endpoint_kind(lua_State *L) +{ + const int param_count = lua_gettop(L); + if (param_count != 1 && param_count != 2) + lua_error_p(L, "expected one or two parameters"); + if (!lua_isstring(L, 1)) { + lua_error_p(L, "incorrect kind '%s'", lua_tostring(L, 1)); + } + size_t kind_len; + const char *kind = lua_tolstring(L, 1, &kind_len); + struct network *net = &engine_luaget(L)->net; + + /* Unregistering */ + if (param_count == 1) { + void *val; + if (trie_del(net->endpoint_kinds, kind, kind_len, &val) == KNOT_EOK) { + const int fun_id = (char *)val - (char *)NULL; + luaL_unref(L, LUA_REGISTRYINDEX, fun_id); + return 0; + } + lua_error_p(L, "attempt to unregister unknown kind '%s'\n", kind); + } /* else */ + + /* Registering */ + assert(param_count == 2); + if (!lua_isfunction(L, 2)) { + lua_error_p(L, "second parameter: expected function but got %s\n", + lua_typename(L, lua_type(L, 2))); + } + const int fun_id = luaL_ref(L, LUA_REGISTRYINDEX); + /* ^^ The function is on top of the stack, incidentally. */ + void **pp = trie_get_ins(net->endpoint_kinds, kind, kind_len); + if (!pp) lua_error_maybe(L, kr_error(ENOMEM)); + if (*pp != NULL || !strcasecmp(kind, "dns") || !strcasecmp(kind, "tls")) + lua_error_p(L, "attempt to register known kind '%s'\n", kind); + *pp = (char *)NULL + fun_id; + /* We don't attempt to engage correspoinding endpoints now. + * That's the job for network_engage_endpoints() later. */ + return 0; +} + int kr_bindings_net(lua_State *L) { static const luaL_Reg lib[] = { @@ -944,6 +1027,7 @@ int kr_bindings_net(lua_State *L) { "tls_handshake_timeout", net_tls_handshake_timeout }, { "bpf_set", net_bpf_set }, { "bpf_clear", net_bpf_clear }, + { "register_endpoint_kind", net_register_endpoint_kind }, { NULL, NULL } }; register_lib(L, "net", lib); diff --git a/daemon/engine.c b/daemon/engine.c index 9ba6d2b63..99212e28a 100644 --- a/daemon/engine.c +++ b/daemon/engine.c @@ -680,10 +680,19 @@ void engine_deinit(struct engine *engine) if (engine == NULL) { return; } + /* Only close sockets and services; no need to clean up mempool. */ - /* Only close sockets and services, - * no need to clean up mempool. */ - network_deinit(&engine->net); + /* Network deinit is split up. We first need to stop listening, + * then we can unload modules during which we still want + * e.g. the endpoint kind registry to work (inside ->net), + * and this registry deinitization uses the lua state. */ + network_close_force(&engine->net); + for (size_t i = 0; i < engine->ipc_set.len; ++i) { + close(engine->ipc_set.at[i]); + } + for (size_t i = 0; i < engine->modules.len; ++i) { + engine_unload(engine, engine->modules.at[i]); + } kr_zonecut_deinit(&engine->resolver.root_hints); kr_cache_close(&engine->resolver.cache); @@ -692,18 +701,8 @@ void engine_deinit(struct engine *engine) lru_free(engine->resolver.cache_rep); lru_free(engine->resolver.cache_cookie); - /* Clear IPC pipes */ - for (size_t i = 0; i < engine->ipc_set.len; ++i) { - close(engine->ipc_set.at[i]); - } - - /* Unload modules and engine. */ - for (size_t i = 0; i < engine->modules.len; ++i) { - engine_unload(engine, engine->modules.at[i]); - } - if (engine->L) { - lua_close(engine->L); - } + network_deinit(&engine->net); + lua_close(engine->L); /* Free data structures */ array_clear(engine->modules); diff --git a/daemon/lua/kres-gen.lua b/daemon/lua/kres-gen.lua index 84db39ca0..afd7a7a87 100644 --- a/daemon/lua/kres-gen.lua +++ b/daemon/lua/kres-gen.lua @@ -348,6 +348,18 @@ int kr_cache_remove(struct kr_cache *, const knot_dname_t *, uint16_t); int kr_cache_remove_subtree(struct kr_cache *, const knot_dname_t *, _Bool, int); int kr_cache_commit(struct kr_cache *); uint32_t packet_ttl(const knot_pkt_t *, _Bool); +typedef struct { + int sock_type; + _Bool tls; + const char *kind; +} endpoint_flags_t; +struct endpoint { + void *handle; + int fd; + uint16_t port; + _Bool engaged; + endpoint_flags_t flags; +}; typedef struct { uint8_t bitmap[32]; uint8_t length; diff --git a/daemon/lua/kres-gen.sh b/daemon/lua/kres-gen.sh index 1365f1fe0..a5702c5d1 100755 --- a/daemon/lua/kres-gen.sh +++ b/daemon/lua/kres-gen.sh @@ -5,6 +5,7 @@ set -o pipefail -o errexit -o nounset cd "$(dirname ${0})" CDEFS="../../scripts/gen-cdefs.sh" LIBKRES="${MESON_BUILD_ROOT}/lib/libkres.so" +KRESD="${MESON_BUILD_ROOT}/daemon/kresd" # Write to kres-gen.lua instead of stdout mv kres-gen.lua{,.bak} ||: @@ -215,6 +216,11 @@ ${CDEFS} ${LIBKRES} functions <<-EOF packet_ttl EOF +## kresd daemon stuff, too +${CDEFS} ${KRESD} types <<-EOF + endpoint_flags_t +EOF +echo "struct endpoint" | ${CDEFS} ${KRESD} types | sed 's/uv_handle_t \*/void */' ## libzscanner API for ./zonefile.lua ${CDEFS} libzscanner types <<-EOF diff --git a/daemon/main.c b/daemon/main.c index 2c3bd3a43..8f0e84c7b 100644 --- a/daemon/main.c +++ b/daemon/main.c @@ -48,14 +48,18 @@ /* @internal Array of ip address shorthand. */ typedef array_t(char*) addr_array_t; +typedef struct { + int fd; + endpoint_flags_t flags; /**< .sock_type isn't meaningful here */ +} flagged_fd_t; +typedef array_t(flagged_fd_t) flagged_fd_array_t; + struct args { + addr_array_t addrs, addrs_tls; + flagged_fd_array_t fds; + int control_fd; int forks; - addr_array_t addr_set; - addr_array_t tls_set; - fd_array_t fd_set; - fd_array_t tls_fd_set; const char *config; - int control_fd; const char *rundir; bool interactive; bool quiet; @@ -365,8 +369,7 @@ static void help(int argc, char *argv[]) printf("\nParameters:\n" " -a, --addr=[addr] Server address (default: localhost@53).\n" " -t, --tls=[addr] Server address for TLS (default: off).\n" - " -S, --fd=[fd] Listen on given fd (handed out by supervisor).\n" - " -T, --tlsfd=[fd] Listen using TLS on given fd (handed out by supervisor).\n" + " -S, --fd=[fd:kind] Listen on given fd (handed out by supervisor, :kind is optional).\n" " -c, --config=[path] Config file path (relative to [rundir]) (default: config).\n" " -f, --forks=N Start N forks sharing the configuration.\n" " -q, --quiet No command prompt in interactive mode.\n" @@ -468,16 +471,23 @@ static void free_sd_socket_names(char **socket_names, int count) static void args_init(struct args *args) { memset(args, 0, sizeof(struct args)); + /* Zeroed arrays are OK. */ args->forks = 1; - array_init(args->addr_set); - array_init(args->tls_set); - array_init(args->fd_set); - array_init(args->tls_fd_set); args->control_fd = -1; args->interactive = true; args->quiet = false; } +/* Free pointed-to resources. */ +static void args_deinit(struct args *args) +{ + array_clear(args->addrs); + array_clear(args->addrs_tls); + for (int i = 0; i < args->fds.len; ++i) + free_const(args->fds.at[i].flags.kind); + array_clear(args->fds); +} + static long strtol_10(const char *s) { if (!s) abort(); @@ -497,7 +507,6 @@ static int parse_args(int argc, char **argv, struct args *args) {"addr", required_argument, 0, 'a'}, {"tls", required_argument, 0, 't'}, {"fd", required_argument, 0, 'S'}, - {"tlsfd", required_argument, 0, 'T'}, {"config", required_argument, 0, 'c'}, {"forks", required_argument, 0, 'f'}, {"verbose", no_argument, 0, 'v'}, @@ -506,20 +515,14 @@ static int parse_args(int argc, char **argv, struct args *args) {"help", no_argument, 0, 'h'}, {0, 0, 0, 0} }; - while ((c = getopt_long(argc, argv, "a:t:S:T:c:f:m:K:k:vqVh", opts, &li)) != -1) { + while ((c = getopt_long(argc, argv, "a:t:S:c:f:m:K:k:vqVh", opts, &li)) != -1) { switch (c) { case 'a': - array_push(args->addr_set, optarg); + array_push(args->addrs, optarg); break; case 't': - array_push(args->tls_set, optarg); - break; - case 'S': - array_push(args->fd_set, strtol_10(optarg)); - break; - case 'T': - array_push(args->tls_fd_set, strtol_10(optarg)); + array_push(args->addrs_tls, optarg); break; case 'c': args->config = optarg; @@ -552,6 +555,28 @@ static int parse_args(int argc, char **argv, struct args *args) default: help(argc, argv); return EXIT_FAILURE; + case 'S': + (void)0; + flagged_fd_t ffd = { 0 }; + char *endptr; + ffd.fd = strtol(optarg, &endptr, 10); + if (endptr != optarg && endptr[0] == '\0') { + /* Plain DNS */ + ffd.flags.tls = false; + } else if (endptr[0] == ':' && strcasecmp(endptr + 1, "tls") == 0) { + /* DoT */ + ffd.flags.tls = true; + /* We know what .sock_type should be but it wouldn't help. */ + } else if (endptr[0] == ':' && endptr[1] != '\0') { + /* Some other kind; no checks here. */ + ffd.flags.kind = strdup(endptr + 1); + } else { + kr_log_error("[system] incorrect value passed to '-S/--fd': %s\n", + optarg); + return EXIT_FAILURE; + } + array_push(args->fds, ffd); + break; } } if (optind < argc) { @@ -560,10 +585,10 @@ static int parse_args(int argc, char **argv, struct args *args) return -1; } -/** Just convert addresses to file-descriptors. +/** Just convert addresses to file-descriptors; clear *addrs on success. * @return zero or exit code for main() */ -static int bind_sockets(addr_array_t *addrs, bool tls, fd_array_t *fds) +static int bind_sockets(addr_array_t *addrs, bool tls, flagged_fd_array_t *fds) { for (size_t i = 0; i < addrs->len; ++i) { uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT; @@ -574,18 +599,19 @@ static int bind_sockets(addr_array_t *addrs, bool tls, fd_array_t *fds) sa = kr_straddr_socket(addr_str, port, NULL); if (!sa) ret = kr_error(EINVAL); /* could be ENOMEM but unlikely */ } + flagged_fd_t ffd = { .flags = { .tls = tls } }; if (ret == 0 && !tls) { - const int fd = io_bind(sa, SOCK_DGRAM); - if (fd < 0) - ret = fd; - else if (array_push(*fds, fd) < 0) + ffd.fd = io_bind(sa, SOCK_DGRAM); + if (ffd.fd < 0) + ret = ffd.fd; + else if (array_push(*fds, ffd) < 0) ret = kr_error(ENOMEM); } if (ret == 0) { /* common for TCP and TLS */ - const int fd = io_bind(sa, SOCK_STREAM); - if (fd < 0) - ret = fd; - else if (array_push(*fds, fd) < 0) + ffd.fd = io_bind(sa, SOCK_STREAM); + if (ffd.fd < 0) + ret = ffd.fd; + else if (array_push(*fds, ffd) < 0) ret = kr_error(ENOMEM); } free(sa); @@ -595,20 +621,28 @@ static int bind_sockets(addr_array_t *addrs, bool tls, fd_array_t *fds) return EXIT_FAILURE; } } + array_clear(*addrs); return kr_ok(); } -static int bind_fds(struct network *net, fd_array_t *fd_set, bool tls) { - int ret = 0; - for (size_t i = 0; i < fd_set->len; ++i) { - ret = network_listen_fd(net, fd_set->at[i], tls); +static int start_listening(struct network *net, flagged_fd_array_t *fds) { + int some_bad_ret = 0; + for (size_t i = 0; i < fds->len; ++i) { + flagged_fd_t *ffd = &fds->at[i]; + int ret = network_listen_fd(net, ffd->fd, ffd->flags); if (ret != 0) { - kr_log_error("[system] %slisten on fd=%d %s\n", - tls ? "TLS " : "", fd_set->at[i], kr_strerror(ret)); - break; + some_bad_ret = ret; + /* TODO: try logging address@port. It's not too important, + * because typical problems happen during binding already. + * (invalid address, permission denied) */ + kr_log_error("[system] listen on fd=%d: %s\n", + ffd->fd, kr_strerror(ret)); + /* Continue printing all of these before exiting. */ + } else { + ffd->flags.kind = NULL; /* ownership transferred */ } } - return ret; + return some_bad_ret; } int main(int argc, char **argv) @@ -618,34 +652,58 @@ int main(int argc, char **argv) int ret = parse_args(argc, argv, &args); if (ret >= 0) goto cleanup_args; - ret = bind_sockets(&args.addr_set, false, &args.fd_set); + ret = bind_sockets(&args.addrs, false, &args.fds); if (ret) goto cleanup_args; - ret = bind_sockets(&args.tls_set, true, &args.tls_fd_set); + ret = bind_sockets(&args.addrs_tls, true, &args.fds); if (ret) goto cleanup_args; #ifdef HAS_SYSTEMD /* Accept passed sockets from systemd supervisor. */ char **socket_names = NULL; int sd_nsocks = sd_listen_fds_with_names(0, &socket_names); + if (sd_nsocks < 0) { + kr_log_error("[system] failed passing sockets from systemd: %s\n", + kr_strerror(sd_nsocks)); + free_sd_socket_names(socket_names, sd_nsocks); + ret = EXIT_FAILURE; + goto cleanup_args; + } + if (sd_nsocks > 0 && args.forks != 1) { + kr_log_error("[system] when run under systemd-style supervision, " + "use single-process only (bad: --forks=%d).\n", args.forks); + free_sd_socket_names(socket_names, sd_nsocks); + ret = EXIT_FAILURE; + goto cleanup_args; + } for (int i = 0; i < sd_nsocks; ++i) { - int fd = SD_LISTEN_FDS_START + i; /* when run under systemd supervision, do not use interactive mode */ args.interactive = false; - if (args.forks != 1) { - kr_log_error("[system] when run under systemd-style supervision, " - "use single-process only (bad: --forks=%d).\n", args.forks); - free_sd_socket_names(socket_names, sd_nsocks); - return EXIT_FAILURE; - } - if (!strcasecmp("control",socket_names[i])) { - args.control_fd = fd; - } else if (!strcasecmp("tls",socket_names[i])) { - array_push(args.tls_fd_set, fd); + flagged_fd_t ffd = { .fd = SD_LISTEN_FDS_START + i }; + + if (!strcasecmp("control", socket_names[i])) { + if (args.control_fd != -1) { + kr_log_error("[system] multiple control sockets passed from systemd\n"); + ret = EXIT_FAILURE; + break; + } + args.control_fd = ffd.fd; + free(socket_names[i]); } else { - array_push(args.fd_set, fd); + if (!strcasecmp("dns", socket_names[i])) { + free(socket_names[i]); + } else if (!strcasecmp("tls", socket_names[i])) { + ffd.flags.tls = true; + free(socket_names[i]); + } else { + ffd.flags.kind = socket_names[i]; + } + array_push(args.fds, ffd); } + /* Either freed or passed ownership. */ + socket_names[i] = NULL; } free_sd_socket_names(socket_names, sd_nsocks); + if (ret) goto cleanup_args; #endif /* Switch to rundir. */ @@ -729,10 +787,8 @@ int main(int argc, char **argv) goto cleanup; } - /* Bind to passed fds and sockets*/ - if (bind_fds(&engine.net, &args.fd_set, false) != 0 || - bind_fds(&engine.net, &args.tls_fd_set, true) != 0 - ) { + /* Start listening, in the sense of network_listen_fd(). */ + if (start_listening(&engine.net, &args.fds) != 0) { ret = EXIT_FAILURE; goto cleanup; } @@ -760,6 +816,11 @@ int main(int argc, char **argv) goto cleanup; } + if (network_engage_endpoints(&engine.net)) { + ret = EXIT_FAILURE; + goto cleanup; + } + /* Run the event loop */ ret = run_worker(loop, &engine, &ipc_set, fork_id == 0, &args); @@ -771,10 +832,7 @@ cleanup:/* Cleanup. */ } mp_delete(pool.ctx); cleanup_args: - array_clear(args.addr_set); - array_clear(args.tls_set); - array_clear(args.fd_set); - array_clear(args.tls_fd_set); + args_deinit(&args); kr_crypto_cleanup(); return ret; } diff --git a/daemon/network.c b/daemon/network.c index 3e431d272..b7ac00988 100644 --- a/daemon/network.c +++ b/daemon/network.c @@ -16,6 +16,7 @@ #include #include +#include "daemon/bindings/impl.h" #include "daemon/network.h" #include "daemon/worker.h" #include "daemon/io.h" @@ -26,6 +27,7 @@ void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog) if (net != NULL) { net->loop = loop; net->endpoints = map_make(NULL); + net->endpoint_kinds = trie_create(NULL); net->tls_client_params = NULL; net->tls_session_ticket_ctx = /* unsync. random, by default */ tls_session_ticket_ctx_create(loop, NULL, 0); @@ -35,26 +37,136 @@ void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog) } } -static void close_handle(uv_handle_t *handle, bool force) +/** Notify the registered function about endpoint getting open. + * If log_port < 1, don't log it. */ +static int endpoint_open_lua_cb(struct network *net, struct endpoint *ep, + const char *log_addr, int log_port) { + const bool ok = ep->flags.kind && !ep->handle && !ep->engaged && ep->fd != -1; + if (!ok) { + assert(!EINVAL); + return kr_error(EINVAL); + } + /* First find callback in the endpoint registry. */ + struct worker_ctx *worker = net->loop->data; // LATER: the_worker + lua_State *L = worker->engine->L; + void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind, + strlen(ep->flags.kind)); + if (!pp && net->missing_kind_is_error) { + kr_log_error("warning: network socket kind '%s' not handled when opening '%s", + ep->flags.kind, log_addr); + if (log_port >= 0) + kr_log_error("#%d", log_port); + kr_log_error("'. Likely causes: typo or not loading 'http' module.\n"); + /* No hard error, for now. LATER: perhaps differentiate between + * explicit net.listen() calls and "just unused" systemd sockets. + return kr_error(ENOENT); + */ + } + if (!pp) return kr_ok(); + + /* Now execute the callback. */ + const int fun_id = (char *)*pp - (char *)NULL; + lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id); + lua_pushboolean(L, true /* open */); + lua_pushpointer(L, ep); + if (log_port < 0) { + lua_pushstring(L, log_addr); + } else { + lua_pushfstring(L, "%s#%d", log_addr, log_port); + } + if (lua_pcall(L, 3, 0, 0)) { + kr_log_error("error opening %s: %s\n", log_addr, lua_tostring(L, -1)); + return kr_error(ENOSYS); /* TODO: better value? */ + } + ep->engaged = true; + return kr_ok(); +} + +static int engage_endpoint_array(const char *key, void *endpoints, void *net) +{ + endpoint_array_t *eps = (endpoint_array_t *)endpoints; + for (int i = 0; i < eps->len; ++i) { + struct endpoint *ep = &eps->at[i]; + const bool match = !ep->engaged && ep->flags.kind; + if (!match) continue; + int ret = endpoint_open_lua_cb(net, ep, key, ep->port); + if (ret) return ret; + } + return 0; +} +int network_engage_endpoints(struct network *net) +{ + if (net->missing_kind_is_error) + return kr_ok(); /* maybe weird, but let's make it idempotent */ + net->missing_kind_is_error = true; + int ret = map_walk(&net->endpoints, engage_endpoint_array, net); + if (ret) { + net->missing_kind_is_error = false; /* avoid the same errors when closing */ + return ret; + } + return kr_ok(); +} + + +/** Notify the registered function about endpoint about to be closed. */ +static void endpoint_close_lua_cb(struct network *net, struct endpoint *ep) +{ + struct worker_ctx *worker = net->loop->data; // LATER: the_worker + lua_State *L = worker->engine->L; + void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind, + strlen(ep->flags.kind)); + if (!pp && net->missing_kind_is_error) { + kr_log_error("internal error: missing kind '%s' in endpoint registry\n", + ep->flags.kind); + return; + } + if (!pp) return; + + const int fun_id = (char *)*pp - (char *)NULL; + lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id); + lua_pushboolean(L, false /* close */); + lua_pushpointer(L, ep); + lua_pushstring(L, "FIXME:endpoint-identifier"); + if (lua_pcall(L, 3, 0, 0)) { + kr_log_error("failed to close FIXME:endpoint-identifier: %s\n", + lua_tostring(L, -1)); + } +} + +static void endpoint_close(struct network *net, struct endpoint *ep, bool force) +{ + assert(!ep->handle != !ep->flags.kind); + if (ep->flags.kind) { /* Special endpoint. */ + if (ep->engaged) { + endpoint_close_lua_cb(net, ep); + } + if (ep->fd > 0) { + close(ep->fd); /* nothing to do with errors */ + } + free_const(ep->flags.kind); + return; + } + if (force) { /* Force close if event loop isn't running. */ - uv_os_fd_t fd = 0; - if (uv_fileno(handle, &fd) == 0) { - close(fd); + if (ep->fd >= 0) { + close(ep->fd); + } + if (ep->handle) { + ep->handle->loop = NULL; + io_free(ep->handle); } - handle->loop = NULL; - io_free(handle); } else { /* Asynchronous close */ - uv_close(handle, io_free); + uv_close(ep->handle, io_free); } } /** Endpoint visitor (see @file map.h) */ -static int close_key(const char *key, void *val, void *ext) +static int close_key(const char *key, void *val, void *net) { endpoint_array_t *ep_array = val; for (int i = 0; i < ep_array->len; ++i) { - close_handle(ep_array->at[i].handle, true); + endpoint_close(net, &ep_array->at[i], true); } return 0; } @@ -67,12 +179,30 @@ static int free_key(const char *key, void *val, void *ext) return kr_ok(); } -void network_deinit(struct network *net) +int kind_unregister(trie_val_t *tv, void *L) +{ + int fun_id = (char *)*tv - (char *)NULL; + luaL_unref(L, LUA_REGISTRYINDEX, fun_id); + return 0; +} + +void network_close_force(struct network *net) { if (net != NULL) { - map_walk(&net->endpoints, close_key, 0); + map_walk(&net->endpoints, close_key, net); map_walk(&net->endpoints, free_key, 0); map_clear(&net->endpoints); + } +} + +void network_deinit(struct network *net) +{ + if (net != NULL) { + network_close_force(net); + struct worker_ctx *worker = net->loop->data; // LATER: the_worker + trie_apply(net->endpoint_kinds, kind_unregister, worker->engine->L); + trie_free(net->endpoint_kinds); + tls_credentials_free(net->tls_credentials); tls_client_params_free(net->tls_client_params); tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx); @@ -108,7 +238,8 @@ static int insert_endpoint(struct network *net, const char *addr, struct endpoin /** Open endpoint protocols. ep->flags were pre-set. */ static int open_endpoint(struct network *net, struct endpoint *ep, - const struct sockaddr *sa, int fd) + const struct sockaddr *sa, int fd, + const char *log_addr, uint16_t log_port) { if ((sa != NULL) == (fd != -1)) { assert(!EINVAL); @@ -122,6 +253,14 @@ static int open_endpoint(struct network *net, struct endpoint *ep, fd = io_bind(sa, ep->flags.sock_type); if (fd < 0) return fd; } + ep->fd = fd; + if (ep->flags.kind) { + /* This EP isn't to be managed internally after binding. */ + return endpoint_open_lua_cb(net, ep, log_addr, log_port); + } else { + ep->engaged = true; + /* .engaged seems not really meaningful with .kind == NULL, but... */ + } if (ep->flags.sock_type == SOCK_DGRAM) { if (ep->flags.tls) { @@ -167,7 +306,8 @@ static struct endpoint * endpoint_get(struct network *net, const char *addr, return NULL; } -/** \note pass either sa != NULL xor fd != -1 */ +/** \note pass either sa != NULL xor fd != -1; + * \note ownership of flags.* is taken on success. */ static int create_endpoint(struct network *net, const char *addr_str, uint16_t port, endpoint_flags_t flags, const struct sockaddr *sa, int fd) @@ -178,27 +318,26 @@ static int create_endpoint(struct network *net, const char *addr_str, .port = port, .flags = flags, }; - int ret = open_endpoint(net, &ep, sa, fd); + int ret = open_endpoint(net, &ep, sa, fd, addr_str, port); if (ret == 0) { ret = insert_endpoint(net, addr_str, &ep); } if (ret != 0 && ep.handle) { - close_handle(ep.handle, false); + endpoint_close(net, &ep, false); } return ret; } -int network_listen_fd(struct network *net, int fd, bool use_tls) +int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags) { /* Extract fd's socket type. */ - endpoint_flags_t flags = { .tls = use_tls }; socklen_t len = sizeof(flags.sock_type); int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len); if (ret != 0) { return kr_error(errno); } - if (flags.sock_type == SOCK_DGRAM && use_tls) { - assert(!EINVAL); + if (flags.sock_type == SOCK_DGRAM && !flags.kind && flags.tls) { + assert(!EINVAL); /* Perhaps DTLS some day. */ return kr_error(EINVAL); } if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) { @@ -267,7 +406,7 @@ int network_close(struct network *net, const char *addr, uint16_t port, while (i < ep_array->len) { struct endpoint *ep = &ep_array->at[i]; if (endpoint_flags_eq(flags, ep->flags)) { - close_handle(ep->handle, false); + endpoint_close(net, ep, false); array_del(*ep_array, i); matched = true; /* do not advance i */ diff --git a/daemon/network.h b/daemon/network.h index 9e8df520e..faa3670c0 100644 --- a/daemon/network.h +++ b/daemon/network.h @@ -20,6 +20,7 @@ #include "lib/generic/array.h" #include "lib/generic/map.h" +#include "lib/generic/trie.h" #include #include @@ -27,22 +28,31 @@ struct engine; -/** Ways to listen for DNS on a port. */ +/** Ways to listen on a socket. */ typedef struct { int sock_type; /**< SOCK_DGRAM or SOCK_STREAM */ - bool tls; /**< only used together with .tcp */ + bool tls; /**< only used together with .tcp; TODO: meaningful if kind != NULL? */ + const char *kind; /**< tag for other types than the three usual */ } endpoint_flags_t; static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2) { - /* memcmp() would typically work, but there's no guarantee. */ - return f1.sock_type == f2.sock_type && f1.tls == f2.tls; + if (f1.sock_type != f2.sock_type) + return false; + if (f1.kind && f2.kind) + return strcasecmp(f1.kind, f2.kind); + else + return f1.tls == f2.tls && f1.kind == f2.kind; } -/** Wrapper for a single socket to listen on. */ +/** Wrapper for a single socket to listen on. + * There are two types: normal have handle, special have flags.kind (and never both). + */ struct endpoint { - uv_handle_t *handle; /** uv_udp_t or uv_tcp_t */ + uv_handle_t *handle; /**< uv_udp_t or uv_tcp_t */ + int fd; uint16_t port; + bool engaged; /**< to some module or internally */ endpoint_flags_t flags; }; @@ -63,6 +73,13 @@ struct network { * TODO: trie_t, keyed on *binary* address-port pair. */ map_t endpoints; + /** Registry of callbacks for special endpoint kinds (for opening/closing). + * Map: kind (lowercased) -> lua function ID converted to void * + * The ID is the usual: raw int index in the LUA_REGISTRYINDEX table. */ + trie_t *endpoint_kinds; + /** See network_engage_endpoints() */ + bool missing_kind_is_error; + struct tls_credentials *tls_credentials; tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */ struct tls_session_ticket_ctx *tls_session_ticket_ctx; @@ -76,18 +93,30 @@ void network_deinit(struct network *net); /** Start listenting on addr#port with flags. * \note if we did listen on that combination already, * nothing is done and kr_ok() is returned. - * \note there's no short-hand to listen both on UDP and TCP. */ + * \note there's no short-hand to listen both on UDP and TCP. + * \note ownership of flags.* is taken on success. TODO: non-success? + */ int network_listen(struct network *net, const char *addr, uint16_t port, endpoint_flags_t flags); -/** Start listenting on an open file-descriptor. */ -int network_listen_fd(struct network *net, int fd, bool use_tls); +/** Start listenting on an open file-descriptor. + * \note flags.sock_type isn't meaningful here. + * \note ownership of flags.* is taken on success. + */ +int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags); /** Stop listening on all addr#port with equal flags. * \return kr_error(ENOENT) if nothing matched. */ int network_close(struct network *net, const char *addr, uint16_t port, endpoint_flags_t flags); +/** Close all endpoints immediately (no waiting for UV loop). */ +void network_close_force(struct network *net); + +/** Enforce that all endpoints are registered from now on. + * This only does anything with struct endpoint::flags.kind != NULL. */ +int network_engage_endpoints(struct network *net); + int network_set_tls_cert(struct network *net, const char *cert); int network_set_tls_key(struct network *net, const char *key); void network_new_hostname(struct network *net, struct engine *engine);