]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: support for more endpoint kinds
authorVladimír Čunát <vladimir.cunat@nic.cz>
Sat, 13 Apr 2019 08:25:20 +0000 (10:25 +0200)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Thu, 18 Apr 2019 07:54:16 +0000 (09:54 +0200)
.gitlab-ci.yml
daemon/bindings/net.c
daemon/engine.c
daemon/lua/kres-gen.lua
daemon/lua/kres-gen.sh
daemon/main.c
daemon/network.c
daemon/network.h

index a2cc0f946aae537f3a31fa4ce93c935e0729157d..511b81289d55a27b1d7399584445f34e90a5e460 100644 (file)
@@ -87,7 +87,7 @@ kres-gen:
     - docker
   script:
     - meson build_ci_lib --prefix=$PREFIX
-    - ninja -C build_ci_lib lib/libkres.so.${LIBKRES_ABI}
+    - ninja -C build_ci_lib daemon/kresd
     - ninja -C build_ci_lib kres-gen
     - git diff --quiet || (git diff; exit 1)
 # }}}
index 30ba7971d72e3cbbe4b41ed933ed206ba4a4cf9d..fc7f23a34630ec89f09a3bef47b27fb99a30ec20 100644 (file)
@@ -66,7 +66,11 @@ static int net_list_add(const char *key, void *val, void *ext)
                lua_setfield(L, -2, "transport");
 
                lua_newtable(L);  // "application" table
-               lua_pushliteral(L, "dns");
+               if (ep->flags.kind) {
+                       lua_pushstring(L, ep->flags.kind);
+               } else {
+                       lua_pushliteral(L, "dns");
+               }
                lua_setfield(L, -2, "protocol");
                lua_setfield(L, -2, "application");
 
@@ -89,8 +93,9 @@ static int net_list(lua_State *L)
 }
 
 /** Listen on an address list represented by the top of lua stack.
+ * \note kind ownership is not transferred
  * \return success */
-static bool net_listen_addrs(lua_State *L, int port, bool tls)
+static bool net_listen_addrs(lua_State *L, int port, bool tls, const char *kind)
 {
        /* Case: table with 'addr' field; only follow that field directly. */
        lua_getfield(L, -1, "addr");
@@ -104,19 +109,25 @@ static bool net_listen_addrs(lua_State *L, int port, bool tls)
        const char *str = lua_tostring(L, -1);
        if (str != NULL) {
                struct engine *engine = engine_luaget(L);
-               endpoint_flags_t flags = { .tls = tls };
                int ret = 0;
-               if (!tls) {
+               endpoint_flags_t flags = { .tls = tls };
+               if (!kind && !flags.tls) { /* normal UDP */
                        flags.sock_type = SOCK_DGRAM;
                        ret = network_listen(&engine->net, str, port, flags);
                }
-               if (ret == 0) { /* common for TCP and TLS */
+               if (!kind && ret == 0) { /* common for normal TCP and TLS */
                        flags.sock_type = SOCK_STREAM;
                        ret = network_listen(&engine->net, str, port, flags);
                }
+               if (kind) {
+                       flags.kind = strdup(kind);
+                       flags.sock_type = SOCK_STREAM; /* TODO: allow to override this? */
+                       ret = network_listen(&engine->net, str, port, flags);
+               }
                if (ret != 0) {
-                       kr_log_info("[system] bind to '%s@%d' %s\n",
-                                       str, port, kr_strerror(ret));
+                       const char *stype = flags.sock_type == SOCK_DGRAM ? "UDP" : "TCP";
+                       kr_log_error("[system] bind to '%s@%d' (%s): %s\n",
+                                       str, port, stype, kr_strerror(ret));
                }
                return ret == 0;
        }
@@ -126,7 +137,7 @@ static bool net_listen_addrs(lua_State *L, int port, bool tls)
                lua_error_p(L, "bad type for address");
        lua_pushnil(L);
        while (lua_next(L, -2)) {
-               if (!net_listen_addrs(L, port, tls))
+               if (!net_listen_addrs(L, port, tls, kind))
                        return false;
                lua_pop(L, 1);
        }
@@ -156,18 +167,47 @@ static int net_listen(lua_State *L)
        }
 
        int port = KR_DNS_PORT;
-       if (n > 1 && lua_isnumber(L, 2)) {
-               port = lua_tointeger(L, 2);
+       if (n > 1) {
+               if (lua_isnumber(L, 2)) {
+                       port = lua_tointeger(L, 2);
+               } else
+               if (!lua_isnil(L, 2)) {
+                       lua_error_p(L, "wrong type of second parameter (port number)");
+               }
        }
 
        bool tls = (port == KR_DNS_TLS_PORT);
-       if (n > 2 && lua_istable(L, 3)) {
+       const char *kind = NULL;
+       if (n > 2) {
+               if (!lua_istable(L, 3))
+                       lua_error_p(L, "wrong type of third parameter (table expected)");
                tls = table_get_flag(L, 3, "tls", tls);
+
+               lua_getfield(L, 3, "kind");
+               const char *k = lua_tostring(L, -1);
+               if (k && strcasecmp(k, "dns") == 0) {
+                       tls = false;
+               } else
+               if (k && strcasecmp(k, "tls") == 0) {
+                       tls = true;
+               } else
+               if (k) {
+                       kind = k;
+               }
+       }
+
+       /* Memory management of `kind` string is difficult due to longjmp etc.
+        * Pop will unreference the lua value, so we store it on C stack instead (!) */
+       const int kind_alen = kind ? strlen(kind) + 1 : 1 /* 0 length isn't C standard */;
+       char kind_buf[kind_alen];
+       if (kind) {
+               memcpy(kind_buf, kind, kind_alen);
+               kind = kind_buf;
        }
 
        /* Now focus on the first argument. */
-       lua_pop(L, n - 1);
-       const bool res = net_listen_addrs(L, port, tls);
+       lua_settop(L, 1);
+       const bool res = net_listen_addrs(L, port, tls, kind);
        lua_pushboolean(L, res);
        return 1;
 }
@@ -180,6 +220,8 @@ static int net_close(lua_State *L)
        if (n < 2)
                lua_error_p(L, "expected 'close(string addr, number port)'");
 
+       /* FIXME: support different kind values */
+
        /* Open resolution context cache */
        struct network *net = &engine_luaget(L)->net;
        const char *addr = lua_tostring(L, 1);
@@ -922,6 +964,47 @@ static int net_bpf_clear(lua_State *L)
        lua_error_p(L, "BPF is not supported on this operating system");
 }
 
+static int net_register_endpoint_kind(lua_State *L)
+{
+       const int param_count = lua_gettop(L);
+       if (param_count != 1 && param_count != 2)
+               lua_error_p(L, "expected one or two parameters");
+       if (!lua_isstring(L, 1)) {
+               lua_error_p(L, "incorrect kind '%s'", lua_tostring(L, 1));
+       }
+       size_t kind_len;
+       const char *kind = lua_tolstring(L, 1, &kind_len);
+       struct network *net = &engine_luaget(L)->net;
+
+       /* Unregistering */
+       if (param_count == 1) {
+               void *val;
+               if (trie_del(net->endpoint_kinds, kind, kind_len, &val) == KNOT_EOK) {
+                       const int fun_id = (char *)val - (char *)NULL;
+                       luaL_unref(L, LUA_REGISTRYINDEX, fun_id);
+                       return 0;
+               }
+               lua_error_p(L, "attempt to unregister unknown kind '%s'\n", kind);
+       } /* else */
+
+       /* Registering */
+       assert(param_count == 2);
+       if (!lua_isfunction(L, 2)) {
+               lua_error_p(L, "second parameter: expected function but got %s\n",
+                               lua_typename(L, lua_type(L, 2)));
+       }
+       const int fun_id = luaL_ref(L, LUA_REGISTRYINDEX);
+               /* ^^ The function is on top of the stack, incidentally. */
+       void **pp = trie_get_ins(net->endpoint_kinds, kind, kind_len);
+       if (!pp) lua_error_maybe(L, kr_error(ENOMEM));
+       if (*pp != NULL || !strcasecmp(kind, "dns") || !strcasecmp(kind, "tls"))
+               lua_error_p(L, "attempt to register known kind '%s'\n", kind);
+       *pp = (char *)NULL + fun_id;
+       /* We don't attempt to engage correspoinding endpoints now.
+        * That's the job for network_engage_endpoints() later. */
+       return 0;
+}
+
 int kr_bindings_net(lua_State *L)
 {
        static const luaL_Reg lib[] = {
@@ -944,6 +1027,7 @@ int kr_bindings_net(lua_State *L)
                { "tls_handshake_timeout",  net_tls_handshake_timeout },
                { "bpf_set",      net_bpf_set },
                { "bpf_clear",    net_bpf_clear },
+               { "register_endpoint_kind", net_register_endpoint_kind },
                { NULL, NULL }
        };
        register_lib(L, "net", lib);
index 9ba6d2b63691aa0c1fb2ea0cee09eaeac147f283..99212e28a8d4aea983b6c834c5d750a5a1aa53c8 100644 (file)
@@ -680,10 +680,19 @@ void engine_deinit(struct engine *engine)
        if (engine == NULL) {
                return;
        }
+       /* Only close sockets and services; no need to clean up mempool. */
 
-       /* Only close sockets and services,
-        * no need to clean up mempool. */
-       network_deinit(&engine->net);
+       /* Network deinit is split up.  We first need to stop listening,
+        * then we can unload modules during which we still want
+        * e.g. the endpoint kind registry to work (inside ->net),
+        * and this registry deinitization uses the lua state. */
+       network_close_force(&engine->net);
+       for (size_t i = 0; i < engine->ipc_set.len; ++i) {
+               close(engine->ipc_set.at[i]);
+       }
+       for (size_t i = 0; i < engine->modules.len; ++i) {
+               engine_unload(engine, engine->modules.at[i]);
+       }
        kr_zonecut_deinit(&engine->resolver.root_hints);
        kr_cache_close(&engine->resolver.cache);
 
@@ -692,18 +701,8 @@ void engine_deinit(struct engine *engine)
        lru_free(engine->resolver.cache_rep);
        lru_free(engine->resolver.cache_cookie);
 
-       /* Clear IPC pipes */
-       for (size_t i = 0; i < engine->ipc_set.len; ++i) {
-               close(engine->ipc_set.at[i]);
-       }
-
-       /* Unload modules and engine. */
-       for (size_t i = 0; i < engine->modules.len; ++i) {
-               engine_unload(engine, engine->modules.at[i]);
-       }
-       if (engine->L) {
-               lua_close(engine->L);
-       }
+       network_deinit(&engine->net);
+       lua_close(engine->L);
 
        /* Free data structures */
        array_clear(engine->modules);
index 84db39ca02a1dd0b0e06b6009feb917ab68abe34..afd7a7a87c4be440ae63c8bb36c68ad9e6f0147f 100644 (file)
@@ -348,6 +348,18 @@ int kr_cache_remove(struct kr_cache *, const knot_dname_t *, uint16_t);
 int kr_cache_remove_subtree(struct kr_cache *, const knot_dname_t *, _Bool, int);
 int kr_cache_commit(struct kr_cache *);
 uint32_t packet_ttl(const knot_pkt_t *, _Bool);
+typedef struct {
+       int sock_type;
+       _Bool tls;
+       const char *kind;
+} endpoint_flags_t;
+struct endpoint {
+       void *handle;
+       int fd;
+       uint16_t port;
+       _Bool engaged;
+       endpoint_flags_t flags;
+};
 typedef struct {
        uint8_t bitmap[32];
        uint8_t length;
index 1365f1fe06007248f6eb828e330dd3152993c631..a5702c5d156634e488d60b52f38adbd6fd1377a7 100755 (executable)
@@ -5,6 +5,7 @@ set -o pipefail -o errexit -o nounset
 cd "$(dirname ${0})"
 CDEFS="../../scripts/gen-cdefs.sh"
 LIBKRES="${MESON_BUILD_ROOT}/lib/libkres.so"
+KRESD="${MESON_BUILD_ROOT}/daemon/kresd"
 
 # Write to kres-gen.lua instead of stdout
 mv kres-gen.lua{,.bak} ||:
@@ -215,6 +216,11 @@ ${CDEFS} ${LIBKRES} functions <<-EOF
        packet_ttl
 EOF
 
+## kresd daemon stuff, too
+${CDEFS} ${KRESD} types <<-EOF
+       endpoint_flags_t
+EOF
+echo "struct endpoint" | ${CDEFS} ${KRESD} types | sed 's/uv_handle_t \*/void */'
 
 ## libzscanner API for ./zonefile.lua
 ${CDEFS} libzscanner types <<-EOF
index 2c3bd3a4330a555b3d533c1ce717a9981ffcc888..8f0e84c7bc1171be013adb3a81cb61bbf0e7c4e0 100644 (file)
 /* @internal Array of ip address shorthand. */
 typedef array_t(char*) addr_array_t;
 
+typedef struct {
+       int fd;
+       endpoint_flags_t flags; /**< .sock_type isn't meaningful here */
+} flagged_fd_t;
+typedef array_t(flagged_fd_t) flagged_fd_array_t;
+
 struct args {
+       addr_array_t addrs, addrs_tls;
+       flagged_fd_array_t fds;
+       int control_fd;
        int forks;
-       addr_array_t addr_set;
-       addr_array_t tls_set;
-       fd_array_t fd_set;
-       fd_array_t tls_fd_set;
        const char *config;
-       int control_fd;
        const char *rundir;
        bool interactive;
        bool quiet;
@@ -365,8 +369,7 @@ static void help(int argc, char *argv[])
        printf("\nParameters:\n"
               " -a, --addr=[addr]      Server address (default: localhost@53).\n"
               " -t, --tls=[addr]       Server address for TLS (default: off).\n"
-              " -S, --fd=[fd]          Listen on given fd (handed out by supervisor).\n"
-              " -T, --tlsfd=[fd]       Listen using TLS on given fd (handed out by supervisor).\n"
+              " -S, --fd=[fd:kind]     Listen on given fd (handed out by supervisor, :kind is optional).\n"
               " -c, --config=[path]    Config file path (relative to [rundir]) (default: config).\n"
               " -f, --forks=N          Start N forks sharing the configuration.\n"
               " -q, --quiet            No command prompt in interactive mode.\n"
@@ -468,16 +471,23 @@ static void free_sd_socket_names(char **socket_names, int count)
 static void args_init(struct args *args)
 {
        memset(args, 0, sizeof(struct args));
+       /* Zeroed arrays are OK. */
        args->forks = 1;
-       array_init(args->addr_set);
-       array_init(args->tls_set);
-       array_init(args->fd_set);
-       array_init(args->tls_fd_set);
        args->control_fd = -1;
        args->interactive = true;
        args->quiet = false;
 }
 
+/* Free pointed-to resources. */
+static void args_deinit(struct args *args)
+{
+       array_clear(args->addrs);
+       array_clear(args->addrs_tls);
+       for (int i = 0; i < args->fds.len; ++i)
+               free_const(args->fds.at[i].flags.kind);
+       array_clear(args->fds);
+}
+
 static long strtol_10(const char *s)
 {
        if (!s) abort();
@@ -497,7 +507,6 @@ static int parse_args(int argc, char **argv, struct args *args)
                {"addr",       required_argument, 0, 'a'},
                {"tls",        required_argument, 0, 't'},
                {"fd",         required_argument, 0, 'S'},
-               {"tlsfd",      required_argument, 0, 'T'},
                {"config",     required_argument, 0, 'c'},
                {"forks",      required_argument, 0, 'f'},
                {"verbose",          no_argument, 0, 'v'},
@@ -506,20 +515,14 @@ static int parse_args(int argc, char **argv, struct args *args)
                {"help",             no_argument, 0, 'h'},
                {0, 0, 0, 0}
        };
-       while ((c = getopt_long(argc, argv, "a:t:S:T:c:f:m:K:k:vqVh", opts, &li)) != -1) {
+       while ((c = getopt_long(argc, argv, "a:t:S:c:f:m:K:k:vqVh", opts, &li)) != -1) {
                switch (c)
                {
                case 'a':
-                       array_push(args->addr_set, optarg);
+                       array_push(args->addrs, optarg);
                        break;
                case 't':
-                       array_push(args->tls_set, optarg);
-                       break;
-               case 'S':
-                       array_push(args->fd_set, strtol_10(optarg));
-                       break;
-               case 'T':
-                       array_push(args->tls_fd_set, strtol_10(optarg));
+                       array_push(args->addrs_tls, optarg);
                        break;
                case 'c':
                        args->config = optarg;
@@ -552,6 +555,28 @@ static int parse_args(int argc, char **argv, struct args *args)
                default:
                        help(argc, argv);
                        return EXIT_FAILURE;
+               case 'S':
+                       (void)0;
+                       flagged_fd_t ffd = { 0 };
+                       char *endptr;
+                       ffd.fd = strtol(optarg, &endptr, 10);
+                       if (endptr != optarg && endptr[0] == '\0') {
+                               /* Plain DNS */
+                               ffd.flags.tls = false;
+                       } else if (endptr[0] == ':' && strcasecmp(endptr + 1, "tls") == 0) {
+                               /* DoT */
+                               ffd.flags.tls = true;
+                               /* We know what .sock_type should be but it wouldn't help. */
+                       } else if (endptr[0] == ':' && endptr[1] != '\0') {
+                               /* Some other kind; no checks here. */
+                               ffd.flags.kind = strdup(endptr + 1);
+                       } else {
+                               kr_log_error("[system] incorrect value passed to '-S/--fd': %s\n",
+                                               optarg);
+                               return EXIT_FAILURE;
+                       }
+                       array_push(args->fds, ffd);
+                       break;
                }
        }
        if (optind < argc) {
@@ -560,10 +585,10 @@ static int parse_args(int argc, char **argv, struct args *args)
        return -1;
 }
 
-/** Just convert addresses to file-descriptors.
+/** Just convert addresses to file-descriptors; clear *addrs on success.
  * @return zero or exit code for main()
  */
-static int bind_sockets(addr_array_t *addrs, bool tls, fd_array_t *fds)
+static int bind_sockets(addr_array_t *addrs, bool tls, flagged_fd_array_t *fds)
 {
        for (size_t i = 0; i < addrs->len; ++i) {
                uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT;
@@ -574,18 +599,19 @@ static int bind_sockets(addr_array_t *addrs, bool tls, fd_array_t *fds)
                        sa = kr_straddr_socket(addr_str, port, NULL);
                        if (!sa) ret = kr_error(EINVAL); /* could be ENOMEM but unlikely */
                }
+               flagged_fd_t ffd = { .flags = { .tls = tls } };
                if (ret == 0 && !tls) {
-                       const int fd = io_bind(sa, SOCK_DGRAM);
-                       if (fd < 0)
-                               ret = fd;
-                       else if (array_push(*fds, fd) < 0)
+                       ffd.fd = io_bind(sa, SOCK_DGRAM);
+                       if (ffd.fd < 0)
+                               ret = ffd.fd;
+                       else if (array_push(*fds, ffd) < 0)
                                ret = kr_error(ENOMEM);
                }
                if (ret == 0) { /* common for TCP and TLS */
-                       const int fd = io_bind(sa, SOCK_STREAM);
-                       if (fd < 0)
-                               ret = fd;
-                       else if (array_push(*fds, fd) < 0)
+                       ffd.fd = io_bind(sa, SOCK_STREAM);
+                       if (ffd.fd < 0)
+                               ret = ffd.fd;
+                       else if (array_push(*fds, ffd) < 0)
                                ret = kr_error(ENOMEM);
                }
                free(sa);
@@ -595,20 +621,28 @@ static int bind_sockets(addr_array_t *addrs, bool tls, fd_array_t *fds)
                        return EXIT_FAILURE;
                }
        }
+       array_clear(*addrs);
        return kr_ok();
 }
 
-static int bind_fds(struct network *net, fd_array_t *fd_set, bool tls) {
-       int ret = 0;
-       for (size_t i = 0; i < fd_set->len; ++i) {
-               ret = network_listen_fd(net, fd_set->at[i], tls);
+static int start_listening(struct network *net, flagged_fd_array_t *fds) {
+       int some_bad_ret = 0;
+       for (size_t i = 0; i < fds->len; ++i) {
+               flagged_fd_t *ffd = &fds->at[i];
+               int ret = network_listen_fd(net, ffd->fd, ffd->flags);
                if (ret != 0) {
-                       kr_log_error("[system] %slisten on fd=%d %s\n",
-                                tls ? "TLS " : "", fd_set->at[i], kr_strerror(ret));
-                       break;
+                       some_bad_ret = ret;
+                       /* TODO: try logging address@port.  It's not too important,
+                        * because typical problems happen during binding already.
+                        * (invalid address, permission denied) */
+                       kr_log_error("[system] listen on fd=%d: %s\n",
+                                       ffd->fd, kr_strerror(ret));
+                       /* Continue printing all of these before exiting. */
+               } else {
+                       ffd->flags.kind = NULL; /* ownership transferred */
                }
        }
-       return ret;
+       return some_bad_ret;
 }
 
 int main(int argc, char **argv)
@@ -618,34 +652,58 @@ int main(int argc, char **argv)
        int ret = parse_args(argc, argv, &args);
        if (ret >= 0) goto cleanup_args;
 
-       ret = bind_sockets(&args.addr_set, false, &args.fd_set);
+       ret = bind_sockets(&args.addrs, false, &args.fds);
        if (ret) goto cleanup_args;
-       ret = bind_sockets(&args.tls_set, true, &args.tls_fd_set);
+       ret = bind_sockets(&args.addrs_tls, true, &args.fds);
        if (ret) goto cleanup_args;
 
 #ifdef HAS_SYSTEMD
        /* Accept passed sockets from systemd supervisor. */
        char **socket_names = NULL;
        int sd_nsocks = sd_listen_fds_with_names(0, &socket_names);
+       if (sd_nsocks < 0) {
+               kr_log_error("[system] failed passing sockets from systemd: %s\n",
+                            kr_strerror(sd_nsocks));
+               free_sd_socket_names(socket_names, sd_nsocks);
+               ret = EXIT_FAILURE;
+               goto cleanup_args;
+       }
+       if (sd_nsocks > 0 && args.forks != 1) {
+               kr_log_error("[system] when run under systemd-style supervision, "
+                            "use single-process only (bad: --forks=%d).\n", args.forks);
+               free_sd_socket_names(socket_names, sd_nsocks);
+               ret = EXIT_FAILURE;
+               goto cleanup_args;
+       }
        for (int i = 0; i < sd_nsocks; ++i) {
-               int fd = SD_LISTEN_FDS_START + i;
                /* when run under systemd supervision, do not use interactive mode */
                args.interactive = false;
-               if (args.forks != 1) {
-                       kr_log_error("[system] when run under systemd-style supervision, "
-                                    "use single-process only (bad: --forks=%d).\n", args.forks);
-                       free_sd_socket_names(socket_names, sd_nsocks);
-                       return EXIT_FAILURE;
-               }
-               if (!strcasecmp("control",socket_names[i])) {
-                       args.control_fd = fd;
-               } else if (!strcasecmp("tls",socket_names[i])) {
-                       array_push(args.tls_fd_set, fd);
+               flagged_fd_t ffd = { .fd = SD_LISTEN_FDS_START + i };
+
+               if (!strcasecmp("control", socket_names[i])) {
+                       if (args.control_fd != -1) {
+                               kr_log_error("[system] multiple control sockets passed from systemd\n");
+                               ret = EXIT_FAILURE;
+                               break;
+                       }
+                       args.control_fd = ffd.fd;
+                       free(socket_names[i]);
                } else {
-                       array_push(args.fd_set, fd);
+                       if (!strcasecmp("dns", socket_names[i])) {
+                               free(socket_names[i]);
+                       } else if (!strcasecmp("tls", socket_names[i])) {
+                               ffd.flags.tls = true;
+                               free(socket_names[i]);
+                       } else {
+                               ffd.flags.kind = socket_names[i];
+                       }
+                       array_push(args.fds, ffd);
                }
+               /* Either freed or passed ownership. */
+               socket_names[i] = NULL;
        }
        free_sd_socket_names(socket_names, sd_nsocks);
+       if (ret) goto cleanup_args;
 #endif
 
        /* Switch to rundir. */
@@ -729,10 +787,8 @@ int main(int argc, char **argv)
                goto cleanup;
        }
 
-       /* Bind to passed fds and sockets*/
-       if (bind_fds(&engine.net, &args.fd_set, false) != 0 ||
-           bind_fds(&engine.net, &args.tls_fd_set, true) != 0
-       ) {
+       /* Start listening, in the sense of network_listen_fd(). */
+       if (start_listening(&engine.net, &args.fds) != 0) {
                ret = EXIT_FAILURE;
                goto cleanup;
        }
@@ -760,6 +816,11 @@ int main(int argc, char **argv)
                goto cleanup;
        }
 
+       if (network_engage_endpoints(&engine.net)) {
+               ret = EXIT_FAILURE;
+               goto cleanup;
+       }
+
        /* Run the event loop */
        ret = run_worker(loop, &engine, &ipc_set, fork_id == 0, &args);
 
@@ -771,10 +832,7 @@ cleanup:/* Cleanup. */
        }
        mp_delete(pool.ctx);
 cleanup_args:
-       array_clear(args.addr_set);
-       array_clear(args.tls_set);
-       array_clear(args.fd_set);
-       array_clear(args.tls_fd_set);
+       args_deinit(&args);
        kr_crypto_cleanup();
        return ret;
 }
index 3e431d272f1e647c0797d34a6578b2249d84214b..b7ac00988a0f8c323677718a5c1cf8faddd24fb0 100644 (file)
@@ -16,6 +16,7 @@
 
 #include <unistd.h>
 #include <assert.h>
+#include "daemon/bindings/impl.h"
 #include "daemon/network.h"
 #include "daemon/worker.h"
 #include "daemon/io.h"
@@ -26,6 +27,7 @@ void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
        if (net != NULL) {
                net->loop = loop;
                net->endpoints = map_make(NULL);
+               net->endpoint_kinds = trie_create(NULL);
                net->tls_client_params = NULL;
                net->tls_session_ticket_ctx = /* unsync. random, by default */
                tls_session_ticket_ctx_create(loop, NULL, 0);
@@ -35,26 +37,136 @@ void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
        }
 }
 
-static void close_handle(uv_handle_t *handle, bool force)
+/** Notify the registered function about endpoint getting open.
+ * If log_port < 1, don't log it. */
+static int endpoint_open_lua_cb(struct network *net, struct endpoint *ep,
+                               const char *log_addr, int log_port)
 {
+       const bool ok = ep->flags.kind && !ep->handle && !ep->engaged && ep->fd != -1;
+       if (!ok) {
+               assert(!EINVAL);
+               return kr_error(EINVAL);
+       }
+       /* First find callback in the endpoint registry. */
+       struct worker_ctx *worker = net->loop->data; // LATER: the_worker
+       lua_State *L = worker->engine->L;
+       void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind,
+                               strlen(ep->flags.kind));
+       if (!pp && net->missing_kind_is_error) {
+               kr_log_error("warning: network socket kind '%s' not handled when opening '%s",
+                               ep->flags.kind, log_addr);
+               if (log_port >= 0)
+                       kr_log_error("#%d", log_port);
+               kr_log_error("'.  Likely causes: typo or not loading 'http' module.\n");
+               /* No hard error, for now.  LATER: perhaps differentiate between
+                * explicit net.listen() calls and "just unused" systemd sockets.
+               return kr_error(ENOENT);
+               */
+       }
+       if (!pp) return kr_ok();
+
+       /* Now execute the callback. */
+       const int fun_id = (char *)*pp - (char *)NULL;
+       lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
+       lua_pushboolean(L, true /* open */);
+       lua_pushpointer(L, ep);
+       if (log_port < 0) {
+               lua_pushstring(L, log_addr);
+       } else {
+               lua_pushfstring(L, "%s#%d", log_addr, log_port);
+       }
+       if (lua_pcall(L, 3, 0, 0)) {
+               kr_log_error("error opening %s: %s\n", log_addr, lua_tostring(L, -1));
+               return kr_error(ENOSYS); /* TODO: better value? */
+       }
+       ep->engaged = true;
+       return kr_ok();
+}
+
+static int engage_endpoint_array(const char *key, void *endpoints, void *net)
+{
+       endpoint_array_t *eps = (endpoint_array_t *)endpoints;
+       for (int i = 0; i < eps->len; ++i) {
+               struct endpoint *ep = &eps->at[i];
+               const bool match = !ep->engaged && ep->flags.kind;
+               if (!match) continue;
+               int ret = endpoint_open_lua_cb(net, ep, key, ep->port);
+               if (ret) return ret;
+       }
+       return 0;
+}
+int network_engage_endpoints(struct network *net)
+{
+       if (net->missing_kind_is_error)
+               return kr_ok(); /* maybe weird, but let's make it idempotent */
+       net->missing_kind_is_error = true;
+       int ret = map_walk(&net->endpoints, engage_endpoint_array, net);
+       if (ret) {
+               net->missing_kind_is_error = false; /* avoid the same errors when closing */
+               return ret;
+       }
+       return kr_ok();
+}
+
+
+/** Notify the registered function about endpoint about to be closed. */
+static void endpoint_close_lua_cb(struct network *net, struct endpoint *ep)
+{
+       struct worker_ctx *worker = net->loop->data; // LATER: the_worker
+       lua_State *L = worker->engine->L;
+       void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind,
+                               strlen(ep->flags.kind));
+       if (!pp && net->missing_kind_is_error) {
+               kr_log_error("internal error: missing kind '%s' in endpoint registry\n",
+                               ep->flags.kind);
+               return;
+       }
+       if (!pp) return;
+
+       const int fun_id = (char *)*pp - (char *)NULL;
+       lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
+       lua_pushboolean(L, false /* close */);
+       lua_pushpointer(L, ep);
+       lua_pushstring(L, "FIXME:endpoint-identifier");
+       if (lua_pcall(L, 3, 0, 0)) {
+               kr_log_error("failed to close FIXME:endpoint-identifier: %s\n",
+                               lua_tostring(L, -1));
+       }
+}
+
+static void endpoint_close(struct network *net, struct endpoint *ep, bool force)
+{
+       assert(!ep->handle != !ep->flags.kind);
+       if (ep->flags.kind) { /* Special endpoint. */
+               if (ep->engaged) {
+                       endpoint_close_lua_cb(net, ep);
+               }
+               if (ep->fd > 0) {
+                       close(ep->fd); /* nothing to do with errors */
+               }
+               free_const(ep->flags.kind);
+               return;
+       }
+
        if (force) { /* Force close if event loop isn't running. */
-               uv_os_fd_t fd = 0;
-               if (uv_fileno(handle, &fd) == 0) {
-                       close(fd);
+               if (ep->fd >= 0) {
+                       close(ep->fd);
+               }
+               if (ep->handle) {
+                       ep->handle->loop = NULL;
+                       io_free(ep->handle);
                }
-               handle->loop = NULL;
-               io_free(handle);
        } else { /* Asynchronous close */
-               uv_close(handle, io_free);
+               uv_close(ep->handle, io_free);
        }
 }
 
 /** Endpoint visitor (see @file map.h) */
-static int close_key(const char *key, void *val, void *ext)
+static int close_key(const char *key, void *val, void *net)
 {
        endpoint_array_t *ep_array = val;
        for (int i = 0; i < ep_array->len; ++i) {
-               close_handle(ep_array->at[i].handle, true);
+               endpoint_close(net, &ep_array->at[i], true);
        }
        return 0;
 }
@@ -67,12 +179,30 @@ static int free_key(const char *key, void *val, void *ext)
        return kr_ok();
 }
 
-void network_deinit(struct network *net)
+int kind_unregister(trie_val_t *tv, void *L)
+{
+       int fun_id = (char *)*tv - (char *)NULL;
+       luaL_unref(L, LUA_REGISTRYINDEX, fun_id);
+       return 0;
+}
+
+void network_close_force(struct network *net)
 {
        if (net != NULL) {
-               map_walk(&net->endpoints, close_key, 0);
+               map_walk(&net->endpoints, close_key, net);
                map_walk(&net->endpoints, free_key, 0);
                map_clear(&net->endpoints);
+       }
+}
+
+void network_deinit(struct network *net)
+{
+       if (net != NULL) {
+               network_close_force(net);
+               struct worker_ctx *worker = net->loop->data; // LATER: the_worker
+               trie_apply(net->endpoint_kinds, kind_unregister, worker->engine->L);
+               trie_free(net->endpoint_kinds);
+
                tls_credentials_free(net->tls_credentials);
                tls_client_params_free(net->tls_client_params);
                tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx);
@@ -108,7 +238,8 @@ static int insert_endpoint(struct network *net, const char *addr, struct endpoin
 
 /** Open endpoint protocols.  ep->flags were pre-set. */
 static int open_endpoint(struct network *net, struct endpoint *ep,
-                        const struct sockaddr *sa, int fd)
+                        const struct sockaddr *sa, int fd,
+                        const char *log_addr, uint16_t log_port)
 {
        if ((sa != NULL) == (fd != -1)) {
                assert(!EINVAL);
@@ -122,6 +253,14 @@ static int open_endpoint(struct network *net, struct endpoint *ep,
                fd = io_bind(sa, ep->flags.sock_type);
                if (fd < 0) return fd;
        }
+       ep->fd = fd;
+       if (ep->flags.kind) {
+               /* This EP isn't to be managed internally after binding. */
+               return endpoint_open_lua_cb(net, ep, log_addr, log_port);
+       } else {
+               ep->engaged = true;
+               /* .engaged seems not really meaningful with .kind == NULL, but... */
+       }
 
        if (ep->flags.sock_type == SOCK_DGRAM) {
                if (ep->flags.tls) {
@@ -167,7 +306,8 @@ static struct endpoint * endpoint_get(struct network *net, const char *addr,
        return NULL;
 }
 
-/** \note pass either sa != NULL xor fd != -1 */
+/** \note pass either sa != NULL xor fd != -1;
+ *  \note ownership of flags.* is taken on success. */
 static int create_endpoint(struct network *net, const char *addr_str,
                                uint16_t port, endpoint_flags_t flags,
                                const struct sockaddr *sa, int fd)
@@ -178,27 +318,26 @@ static int create_endpoint(struct network *net, const char *addr_str,
                .port = port,
                .flags = flags,
        };
-       int ret = open_endpoint(net, &ep, sa, fd);
+       int ret = open_endpoint(net, &ep, sa, fd, addr_str, port);
        if (ret == 0) {
                ret = insert_endpoint(net, addr_str, &ep);
        }
        if (ret != 0 && ep.handle) {
-               close_handle(ep.handle, false);
+               endpoint_close(net, &ep, false);
        }
        return ret;
 }
 
-int network_listen_fd(struct network *net, int fd, bool use_tls)
+int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags)
 {
        /* Extract fd's socket type. */
-       endpoint_flags_t flags = { .tls = use_tls };
        socklen_t len = sizeof(flags.sock_type);
        int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len);
        if (ret != 0) {
                return kr_error(errno);
        }
-       if (flags.sock_type == SOCK_DGRAM && use_tls) {
-               assert(!EINVAL);
+       if (flags.sock_type == SOCK_DGRAM && !flags.kind && flags.tls) {
+               assert(!EINVAL); /* Perhaps DTLS some day. */
                return kr_error(EINVAL);
        }
        if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) {
@@ -267,7 +406,7 @@ int network_close(struct network *net, const char *addr, uint16_t port,
        while (i < ep_array->len) {
                struct endpoint *ep = &ep_array->at[i];
                if (endpoint_flags_eq(flags, ep->flags)) {
-                       close_handle(ep->handle, false);
+                       endpoint_close(net, ep, false);
                        array_del(*ep_array, i);
                        matched = true;
                        /* do not advance i */
index 9e8df520edb7375c1cc935e51b7f907497591038..faa3670c0183bb85cedf918f45aeaa05a0afeb32 100644 (file)
@@ -20,6 +20,7 @@
 
 #include "lib/generic/array.h"
 #include "lib/generic/map.h"
+#include "lib/generic/trie.h"
 
 #include <uv.h>
 #include <stdbool.h>
 
 struct engine;
 
-/** Ways to listen for DNS on a port. */
+/** Ways to listen on a socket. */
 typedef struct {
        int sock_type;  /**< SOCK_DGRAM or SOCK_STREAM */
-       bool tls;       /**< only used together with .tcp */
+       bool tls;       /**< only used together with .tcp; TODO: meaningful if kind != NULL? */
+       const char *kind; /**< tag for other types than the three usual */
 } endpoint_flags_t;
 
 static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2)
 {
-       /* memcmp() would typically work, but there's no guarantee. */
-       return f1.sock_type == f2.sock_type && f1.tls == f2.tls;
+       if (f1.sock_type != f2.sock_type)
+               return false;
+       if (f1.kind && f2.kind)
+               return strcasecmp(f1.kind, f2.kind);
+       else
+               return f1.tls == f2.tls && f1.kind == f2.kind;
 }
 
-/** Wrapper for a single socket to listen on. */
+/** Wrapper for a single socket to listen on.
+ * There are two types: normal have handle, special have flags.kind (and never both).
+ */
 struct endpoint {
-       uv_handle_t *handle; /** uv_udp_t or uv_tcp_t */
+       uv_handle_t *handle; /**< uv_udp_t or uv_tcp_t */
+       int fd;
        uint16_t port;
+       bool engaged; /**< to some module or internally */
        endpoint_flags_t flags;
 };
 
@@ -63,6 +73,13 @@ struct network {
         * TODO: trie_t, keyed on *binary* address-port pair. */
        map_t endpoints;
 
+       /** Registry of callbacks for special endpoint kinds (for opening/closing).
+        * Map: kind (lowercased) -> lua function ID converted to void *
+        * The ID is the usual: raw int index in the LUA_REGISTRYINDEX table. */
+       trie_t *endpoint_kinds;
+       /** See network_engage_endpoints() */
+       bool missing_kind_is_error;
+
        struct tls_credentials *tls_credentials;
        tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */
        struct tls_session_ticket_ctx *tls_session_ticket_ctx;
@@ -76,18 +93,30 @@ void network_deinit(struct network *net);
 /** Start listenting on addr#port with flags.
  * \note if we did listen on that combination already,
  *       nothing is done and kr_ok() is returned.
- * \note there's no short-hand to listen both on UDP and TCP. */
+ * \note there's no short-hand to listen both on UDP and TCP.
+ * \note ownership of flags.* is taken on success.  TODO: non-success?
+ */
 int network_listen(struct network *net, const char *addr, uint16_t port,
                   endpoint_flags_t flags);
 
-/** Start listenting on an open file-descriptor. */
-int network_listen_fd(struct network *net, int fd, bool use_tls);
+/** Start listenting on an open file-descriptor.
+ * \note flags.sock_type isn't meaningful here.
+ * \note ownership of flags.* is taken on success.
+ */
+int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags);
 
 /** Stop listening on all addr#port with equal flags.
  * \return kr_error(ENOENT) if nothing matched. */
 int network_close(struct network *net, const char *addr, uint16_t port,
                  endpoint_flags_t flags);
 
+/** Close all endpoints immediately (no waiting for UV loop). */
+void network_close_force(struct network *net);
+
+/** Enforce that all endpoints are registered from now on.
+ * This only does anything with struct endpoint::flags.kind != NULL. */
+int network_engage_endpoints(struct network *net);
+
 int network_set_tls_cert(struct network *net, const char *cert);
 int network_set_tls_key(struct network *net, const char *key);
 void network_new_hostname(struct network *net, struct engine *engine);