--- /dev/null
+install:
+ rm -rf build
+ rm -rf /tmp/kr
+ meson build -Db_sanitize=address --prefix=/tmp/kr --default-library=static
+ ninja -C build
+ ninja install -C build
+
+shm_clean:
+ make -C /home/jetconf/kres-sysrepo/sysrepo/build shm_clean
\ No newline at end of file
'common/sysrepo_conf.h',
'common/string_helper.h',
'common/string_helper.c',
+ 'common/sysrepo.h',
+ 'common/sysrepo.c',
])
c_src_lint += sysrepo_common_src
--- /dev/null
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <string.h>
+#include <signal.h>
+#include <inttypes.h>
+
+#include <uv.h>
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#include "contrib/ccan/asprintf/asprintf.h"
+#include "lib/utils.h"
+#include "sysrepo.h"
+
+
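+/* Convert a sysrepo value into a newly allocated string.
+ * The caller is responsible for freeing *strval; container, list and
+ * empty-leaf values produce no string and leave *strval set to NULL. */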
+int valtostr(const sr_val_t *value, char **strval)
+{
+	if (NULL == value || NULL == strval) {
+		return 1;
+	}
+	/* default: no string representation */
+	*strval = NULL;
+
+	switch (value->type) {
+ case SR_CONTAINER_T:
+ case SR_CONTAINER_PRESENCE_T:
+ break;
+ case SR_LIST_T:
+ break;
+ case SR_STRING_T:
+ asprintf(strval, "%s", value->data.string_val);
+ break;
+ case SR_BOOL_T:
+ asprintf(strval, "%s", value->data.bool_val ? "true" : "false");
+ break;
+ case SR_DECIMAL64_T:
+ asprintf(strval, "%g", value->data.decimal64_val);
+ break;
+ case SR_INT8_T:
+ asprintf(strval, "%"PRId8, value->data.int8_val);
+ break;
+ case SR_INT16_T:
+ asprintf(strval, "%"PRId16, value->data.int16_val);
+ break;
+ case SR_INT32_T:
+ asprintf(strval, "%"PRId32, value->data.int32_val);
+ break;
+ case SR_INT64_T:
+ asprintf(strval, "%"PRId64, value->data.int64_val);
+ break;
+ case SR_UINT8_T:
+ asprintf(strval, "%"PRIu8, value->data.uint8_val);
+ break;
+ case SR_UINT16_T:
+ asprintf(strval, "%"PRIu16, value->data.uint16_val);
+ break;
+ case SR_UINT32_T:
+ asprintf(strval, "%"PRIu32, value->data.uint32_val);
+ break;
+ case SR_UINT64_T:
+ asprintf(strval, "%"PRIu64, value->data.uint64_val);
+ break;
+ case SR_IDENTITYREF_T:
+ asprintf(strval, "%s", value->data.identityref_val);
+ break;
+ case SR_INSTANCEID_T:
+ asprintf(strval, "%s", value->data.instanceid_val);
+ break;
+ case SR_BITS_T:
+ asprintf(strval, "%s", value->data.bits_val);
+ break;
+ case SR_BINARY_T:
+ asprintf(strval, "%s", value->data.binary_val);
+ break;
+ case SR_ENUM_T:
+ asprintf(strval, "%s", value->data.enum_val);
+ break;
+ case SR_LEAF_EMPTY_T:
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+static void sysrepo_subscr_finish_closing(uv_handle_t *handle)
+{
+ sysrepo_uv_ctx_t *sysrepo = handle->data;
+ assert(sysrepo);
+ free(sysrepo);
+}
+
+/** Free an event loop subscription. */
+static void sysrepo_subscription_free(sysrepo_uv_ctx_t *sysrepo)
+{
+ sr_disconnect(sysrepo->connection);
+ uv_close((uv_handle_t *)&sysrepo->uv_handle, sysrepo_subscr_finish_closing);
+}
+
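+/* Trampoline from the libuv poll watcher to the sysrepo handler below.
+ * This integration assumes the subscription is created with the
+ * SR_SUBSCR_NO_THREAD flag, so sysrepo spawns no handler thread of its own;
+ * its event pipe is polled by libuv instead and pending events are processed
+ * from the loop via sr_process_events(). */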
+static void sysrepo_subscr_cb_tramp(uv_poll_t *handle, int status, int events)
+{
+ sysrepo_uv_ctx_t *sysrepo = handle->data;
+ sysrepo->callback(sysrepo, status);
+}
+
+static void sysrepo_subscr_cb(sysrepo_uv_ctx_t *sysrepo, int status)
+{
+ if (status) {
+ /* some error */
+ return;
+ }
+ /* normal state */
+	sr_process_events(sysrepo->subscription, sysrepo->session, NULL);
+}
+
+sysrepo_uv_ctx_t *sysrepo_ctx_init(void)
+{
+ int ret = SR_ERR_OK;
+ sr_conn_ctx_t *sr_connection = NULL;
+ sr_session_ctx_t *sr_session = NULL;
+ sr_subscription_ctx_t *sr_subscription = NULL;
+
+ if (!ret) ret = sr_connect(0, &sr_connection);
+ if (!ret) ret = sr_session_start(sr_connection, SR_DS_RUNNING, &sr_session);
+	if (ret) {
+		kr_log_error(
+			"[sysrepo] failed to start sysrepo session: %s\n",
+			sr_strerror(ret));
+		if (sr_connection)
+			sr_disconnect(sr_connection);
+		return NULL;
+	}
+
+	sysrepo_uv_ctx_t *sysrepo = malloc(sizeof(sysrepo_uv_ctx_t));
+	if (!sysrepo) {
+		kr_log_error("[sysrepo] failed to allocate sysrepo context\n");
+		sr_disconnect(sr_connection);
+		return NULL;
+	}
+	sysrepo->connection = sr_connection;
+	sysrepo->session = sr_session;
+	sysrepo->callback = sysrepo_subscr_cb;
+	sysrepo->subscription = sr_subscription;
+
+ return sysrepo;
+}
+
+int sysrepo_ctx_start(uv_loop_t *loop, sysrepo_uv_ctx_t *sysrepo)
+{
+ int ret = SR_ERR_OK;
+
+ int pipe;
+ ret = sr_get_event_pipe(sysrepo->subscription, &pipe);
+ if (ret != SR_ERR_OK) {
+ kr_log_error("[sysrepo] failed to get sysrepo event pipe: %s\n", sr_strerror(ret));
+ free(sysrepo);
+ return ret;
+ }
+ ret = uv_poll_init(loop, &sysrepo->uv_handle, pipe);
+ if (ret) {
+ kr_log_error("[libuv] failed to initialize uv_poll: %s\n", uv_strerror(ret));
+ free(sysrepo);
+ return ret;
+ }
+ sysrepo->uv_handle.data = sysrepo;
+ ret = uv_poll_start(&sysrepo->uv_handle, UV_READABLE, sysrepo_subscr_cb_tramp);
+ if (ret) {
+ kr_log_error("[libuv] failed to start uv_poll: %s\n", uv_strerror(ret));
+ sysrepo_subscription_free(sysrepo);
+ }
+ return ret;
+}
+
+int sysrepo_ctx_deinit(sysrepo_uv_ctx_t *sysrepo)
+{
+ sysrepo_subscription_free(sysrepo);
+
+ return 0;
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <uv.h>
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#define YM_COMMON "cznic-resolver-common"
+#define YM_KRES "cznic-resolver-knot"
+#define XPATH_BASE "/" YM_COMMON ":dns-resolver"
+#define XPATH_RPC_BASE "/" YM_COMMON
+#define XPATH_GC XPATH_BASE "/cache/" YM_KRES ":garbage-collector"
+#define XPATH_LOG XPATH_BASE "/logging/" YM_KRES ":log"
+
+
+typedef struct sysrepo_uv_ctx sysrepo_uv_ctx_t;
+/** Callback for sysrepo subscriptions */
+typedef void (*sysrepo_cb)(sysrepo_uv_ctx_t *sysrepo, int status);
+
+/** Context for sysrepo subscriptions.
+ * More fields may be added in the future. */
+struct sysrepo_uv_ctx {
+ sr_conn_ctx_t *connection;
+ sr_session_ctx_t *session;
+ sr_subscription_ctx_t *subscription;
+ sysrepo_cb callback;
+ uv_poll_t uv_handle;
+};
+
+/** Initialize a sysrepo context */
+sysrepo_uv_ctx_t *sysrepo_ctx_init(void);
+
+/** Start polling sysrepo subscription events on the given event loop */
+int sysrepo_ctx_start(uv_loop_t *loop, sysrepo_uv_ctx_t *sysrepo);
+
+/** Destroy sysrepo context */
+int sysrepo_ctx_deinit(sysrepo_uv_ctx_t *sysrepo);
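+
+/*
+ * Minimal usage sketch (illustration only; it assumes an existing
+ * uv_loop_t *loop and a subscription created elsewhere with the
+ * SR_SUBSCR_NO_THREAD flag and stored in ctx->subscription before
+ * sysrepo_ctx_start() is called, so events arrive via the polled pipe):
+ *
+ *	sysrepo_uv_ctx_t *ctx = sysrepo_ctx_init();
+ *	if (!ctx)
+ *		return 1;
+ *	// ... create the subscription, store it in ctx->subscription ...
+ *	if (sysrepo_ctx_start(loop, ctx) != 0)
+ *		return 1;
+ *	uv_run(loop, UV_RUN_DEFAULT);
+ *	sysrepo_ctx_deinit(ctx);
+ */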
+
+/** Logging function */
+int send_log();
+
+/** Convert node's value to string */
+int valtostr(const sr_val_t *value, char **strval);
+
+
+
'common/sysrepo_conf.h',
'common/string_helper.h',
'common/string_helper.c',
+ 'common/sysrepo.h',
+ 'common/sysrepo.c',
])
c_src_lint += sysrepo_common_src
-if build_sysrepo
- sysrepo_mod = shared_module(
- 'sysrepo',
- sysrepo_src,
- sysrepo_common_src,
- dependencies: [
- luajit_inc,
- libyang,
- libsysrepo,
- ],
- include_directories: mod_inc_dir,
- name_prefix: '',
- install: true,
- install_dir: modules_dir,
- )
-endif
\ No newline at end of file
+# if build_sysrepo
+# sysrepo_mod = shared_module(
+# 'sysrepo',
+# sysrepo_src,
+# sysrepo_common_src,
+# dependencies: [
+# luajit_inc,
+# libyang,
+# libsysrepo,
+# ],
+# include_directories: mod_inc_dir,
+# name_prefix: '',
+# install: true,
+# install_dir: modules_dir,
+# )
+# endif
\ No newline at end of file
endif
-if build_client
- kresc = executable(
- 'kresc',
- kresc_src,
- dependencies: [
- contrib_dep,
- libedit,
- ],
- install: true,
- install_dir: get_option('sbindir'),
- )
-endif
+# if build_client
+# kresc = executable(
+# 'kresc',
+# kresc_src,
+# dependencies: [
+# contrib_dep,
+# libedit,
+# ],
+# install: true,
+# install_dir: get_option('sbindir'),
+# )
+# endif
+++ /dev/null
-int main(int argc, char *argv[])
-{
- return 0;
-}
\ No newline at end of file
+++ /dev/null
-kres_watcher_src = files([
- 'main.c',
-])
-c_src_lint += kres_watcher_src
-
-if build_sysrepo
- kres_watcher = executable(
- 'kres-watcher',
- kres_watcher_src,
- sysrepo_common_src,
- dependencies: [
- contrib_dep,
- libkres_dep,
- libyang,
- libsysrepo,
- libknot,
- luajit_inc,
- ],
- install: true,
- install_dir: get_option('sbindir'),
- )
-endif
\ No newline at end of file
--- /dev/null
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+#include <errno.h>
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#include "lib/generic/array.h"
+#include "modules/sysrepo/common/string_helper.h"
+#include "modules/sysrepo/common/sysrepo.h"
+#include "commands.h"
+#include "process.h"
+#include "conf_file.h"
+
+/* Built-in commands */
+#define CMD_EXIT "exit"
+#define CMD_HELP "help"
+#define CMD_VERSION "version"
+#define CMD_IMPORT "import"
+#define CMD_EXPORT "export"
+#define CMD_BEGIN "begin"
+#define CMD_COMMIT "commit"
+#define CMD_ABORT "abort"
+#define CMD_VALIDATE "validate"
+#define CMD_DIFF "diff"
+#define CMD_PERSIST "persist"
+
+
+static int cmd_import(cmd_args_t *args)
+{
+	struct lyd_node *data = NULL;
+ sr_session_ctx_t *sr_session = NULL;
+ const char *file_path = args->argv[0];
+ int flags = LYD_OPT_CONFIG | LYD_OPT_TRUSTED | LYD_OPT_STRICT;
+
+ int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sr_session);
+ if (ret) {
+ printf("failed to start sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+	ret = step_load_data(sr_session, file_path, flags, &data);
+	if (ret) {
+		printf("failed to load configuration data\n");
+		sr_session_stop(sr_session);
+		return CLI_ECMD;
+	}
+
+	/* replace the configuration (sysrepo always spends the data) */
+	ret = sr_replace_config(sr_session, YM_COMMON, data, 0, 0);
+ if (ret) {
+ printf("failed to replace configuration, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ ret = sr_session_stop(sr_session);
+ if (ret) {
+ printf("failed to stop sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ return CLI_EOK;
+}
+
+static int cmd_export(cmd_args_t *args)
+{
+ char *xpath;
+ struct lyd_node *data;
+ sr_session_ctx_t *sr_session = NULL;
+ FILE *file = NULL;
+
+	/* If an argument is given, open the file for writing */
+ if (args->argc == 1) {
+ file = fopen(args->argv[0], "w");
+ if (!file) {
+			printf("Failed to open \"%s\" for writing (%s)\n", args->argv[0], strerror(errno));
+ return CLI_ECMD;
+ }
+ }
+
+ asprintf(&xpath, "/%s:*", YM_COMMON);
+ int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sr_session);
+ if (!ret) ret = sr_get_data(sr_session, xpath, 0, 0, 0, &data);
+	if (ret) {
+		printf("failed to get configuration from sysrepo, %s\n", sr_strerror(ret));
+		if (sr_session)
+			sr_session_stop(sr_session);
+		free(xpath);
+		if (file)
+			fclose(file);
+		return CLI_ECMD;
+	}
+
+ ret = sr_session_stop(sr_session);
+ if (ret) {
+ printf("failed to stop sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+	/* print the exported data to the file or stdout */
+	lyd_print_file(file ? file : stdout, data, LYD_JSON, LYP_FORMAT | LYP_WITHSIBLINGS);
+	lyd_free_withsiblings(data);
+	free(xpath);
+	if (file)
+		fclose(file);
+
+ return CLI_EOK;
+}
+
+static int cmd_begin(cmd_args_t *args)
+{
+ sr_session_ctx_t *sr_session = NULL;
+
+ if (sysrepo_ctx->session) {
+		printf("a transaction has already begun\n");
+ return CLI_ECMD;
+ }
+
+ int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_CANDIDATE, &sr_session);
+ if (ret) {
+ printf("failed to start sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ ret = sr_lock(sr_session, YM_COMMON);
+ if (ret) {
+ printf("failed to lock candidate datastore, %s\n", sr_strerror(ret));
+ sr_session_stop(sr_session);
+ return CLI_ECMD;
+ }
+ sysrepo_ctx->session = sr_session;
+
+ return CLI_EOK;
+}
+
+static int cmd_commit(cmd_args_t *args)
+{
+ if (!sysrepo_ctx->session){
+ printf("no active transaction\n");
+ return CLI_ECMD;
+ }
+
+	int ret = sr_validate(sysrepo_ctx->session, YM_COMMON, 0);
+	if (ret) {
+		printf("validation failed, %s\n", sr_strerror(ret));
+		return CLI_ECMD;
+	}
+
+ /* switch datastore to RUNNING */
+ ret = sr_session_switch_ds(sysrepo_ctx->session, SR_DS_RUNNING);
+ /* copy configuration from CANDIDATE to RUNNING datastore */
+ if (!ret) ret = sr_copy_config(sysrepo_ctx->session, YM_COMMON,
+ SR_DS_CANDIDATE, args->timeout, 0);
+ if (ret) {
+ printf("commit failed, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ ret = sr_session_stop(sysrepo_ctx->session);
+ if (ret) {
+ printf("failed to stop sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+ sysrepo_ctx->session = NULL;
+
+ return CLI_EOK;
+}
+
+static int cmd_abort(cmd_args_t *args)
+{
+ if (!sysrepo_ctx->session){
+ printf("no active transaction\n");
+ return 1;
+ }
+
+ int ret = sr_session_stop(sysrepo_ctx->session);
+ if (ret) {
+ printf("failed to stop sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+ sysrepo_ctx->session = NULL;
+
+ return CLI_EOK;
+}
+
+static int cmd_validate(cmd_args_t *args)
+{
+ if (!sysrepo_ctx->session){
+ printf("no active transaction\n");
+ return CLI_ECMD;
+ }
+
+ int ret = sr_validate(sysrepo_ctx->session, YM_COMMON, args->timeout);
+ if (ret) {
+ printf("validation failed, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+ return CLI_EOK;
+}
+
+static int cmd_diff(cmd_args_t *args)
+{
+ return CLI_EOK;
+}
+
+static int cmd_persist(cmd_args_t *args)
+{
+ sr_session_ctx_t *sr_session = NULL;
+
+ int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_STARTUP, &sr_session);
+ if (ret) {
+ printf("failed to start sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ /* copy configuration from RUNNING to STARTUP datastore */
+ if (!ret) ret = sr_copy_config(sr_session, YM_COMMON, SR_DS_RUNNING, 0, 0);
+ if (ret) {
+ printf("commit failed, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ ret = sr_session_stop(sr_session);
+ if (ret) {
+ printf("failed to stop sysrepo session, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+
+ return CLI_EOK;
+}
+
+/* Functions for YANG commands */
+
+static int cmd_leaf(cmd_args_t *args)
+{
+ int ret = CLI_EOK;
+ const char *path = args->desc->xpath;
+
+ /* get operational data */
+ if (!args->argc) {
+
+		char *strval = NULL;
+ sr_val_t *val = NULL;
+ sr_session_ctx_t *sr_session = NULL;
+
+ if (!ret) ret = sr_session_start(sysrepo_ctx->connection, SR_DS_OPERATIONAL, &sr_session);
+ if (!ret) ret = sr_get_item(sr_session, path, args->timeout, &val);
+ if (ret) {
+ printf("get configuration data failed, %s\n", sr_strerror(ret));
+ sr_session_stop(sr_session);
+ return CLI_ECMD;
+ }
+
+ valtostr(val, &strval);
+ printf("%s = %s\n", args->desc->name, strval);
+
+ sr_session_stop(sr_session);
+ sr_free_val(val);
+ free(strval);
+ }
+ /* set configuration data */
+ else if (args->argc == 1) {
+ /* check if there is active session */
+ if (!sysrepo_ctx->session) {
+ printf("no active transaction\n");
+ return CLI_ECMD;
+ }
+
+		if (!ret) ret = sr_session_switch_ds(sysrepo_ctx->session, SR_DS_CANDIDATE);
+ if (!ret) ret = sr_set_item_str(sysrepo_ctx->session, path, args->argv[0], NULL, 0);
+ if (!ret) ret = sr_apply_changes(sysrepo_ctx->session, 0, 0);
+ if (ret) {
+ printf("set data value failed, %s\n", sr_strerror(ret));
+ return CLI_ECMD;
+ }
+ }
+ else {
+ printf("too many arguments\n");
+ return CLI_ECMD;
+ }
+
+ return CLI_EOK;
+}
+
+static int cmd_container(cmd_args_t *args)
+{
+ int ret = CLI_EOK;
+ const char *path = args->desc->xpath;
+
+ /* get operational data */
+ if (!args->argc) {
+
+ char *xpath;
+ struct lyd_node *data = NULL;
+ sr_session_ctx_t *sr_session = NULL;
+
+ asprintf(&xpath, "%s/*//.", args->desc->xpath);
+ if (!ret) ret = sr_session_start(sysrepo_ctx->connection, SR_DS_OPERATIONAL, &sr_session);
+ if (!ret) ret = sr_get_data(sr_session, xpath, 0, args->timeout, 0, &data);
+ if (ret) {
+ printf("get configuration data failed, %s\n", sr_strerror(ret));
+ sr_session_stop(sr_session);
+ return CLI_ECMD;
+ }
+
+ lyd_print_file(stdout, data, LYD_JSON, LYP_FORMAT | LYP_WITHSIBLINGS);
+
+ lyd_free_withsiblings(data);
+ sr_session_stop(sr_session);
+ free(xpath);
+ }
+ else {
+ printf("too many arguments\n");
+ return CLI_ECMD;
+ }
+
+ return CLI_EOK;
+}
+
+static int cmd_list(cmd_args_t *args)
+{
+ return CLI_EOK;
+}
+
+static int cmd_leaflist(cmd_args_t *args)
+{
+ return CLI_EOK;
+}
+
+static int cmd_rpc(cmd_args_t *args)
+{
+ int ret = CLI_EOK;
+ sr_session_ctx_t *sr_session = NULL;
+ sr_val_t *output = NULL;
+ size_t output_count = 0;
+
+ //TODO: prepare input
+
+ ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sr_session);
+ if (!ret) ret = sr_rpc_send(sr_session, args->desc->xpath, 0, 0, args->timeout, &output, &output_count);
+ if (ret) {
+		printf("failed to send RPC operation, %s\n", sr_strerror(ret));
+ sr_session_stop(sr_session);
+ return CLI_ECMD;
+ }
+
+ sr_free_values(output, output_count);
+ sr_session_stop(sr_session);
+ return CLI_EOK;
+}
+
+static int cmd_notif(cmd_args_t *args)
+{
+ return CLI_EOK;
+}
+
+
+static void cmd_dynarray_deep_free(cmd_dynarray_t * d)
+{
+ dynarray_foreach(cmd, cmd_desc_t *, i, *d) {
+ cmd_desc_t *cmd = *i;
+		free((char *)cmd->xpath);
+		free((char *)cmd->name);
+ free(cmd);
+ }
+ cmd_dynarray_free(d);
+}
+
+cmd_dynarray_t dyn_cmd_table;
+
+const cmd_desc_t cmd_table[] = {
+ /* name, function, xpath, flags */
+ { CMD_EXIT, NULL, "", CMD_FNONE },
+ { CMD_HELP, print_commands, "", CMD_FNONE },
+ { CMD_VERSION, print_version, "", CMD_FNONE },
+ /* Configuration file */
+ { CMD_IMPORT, cmd_import, "", CMD_FNONE },
+ { CMD_EXPORT, cmd_export, "", CMD_FNONE },
+ /* Transaction */
+ { CMD_BEGIN, cmd_begin, "", CMD_FINTER },
+	{ CMD_COMMIT, cmd_commit, "", CMD_FINTER },
+ { CMD_ABORT, cmd_abort, "", CMD_FINTER },
+ { CMD_VALIDATE, cmd_validate, "", CMD_FINTER },
+ { CMD_DIFF, cmd_diff, "", CMD_FINTER },
+ { CMD_PERSIST, cmd_persist, "", CMD_FNONE },
+ /* */
+ { NULL }
+};
+
+static void cmd_help_dynarray_deep_free(cmd_help_dynarray_t * d)
+{
+ dynarray_foreach(cmd_help, cmd_help_t *, i, *d) {
+ cmd_help_t *cmd_help = *i;
+ free(cmd_help);
+ }
+ cmd_help_dynarray_free(d);
+}
+
+cmd_help_dynarray_t dyn_cmd_help_table;
+
+const cmd_help_t cmd_help_table[] = {
+ /* name, arguments, description */
+ { CMD_EXIT, "", "Exit the program." },
+ { CMD_HELP, "", "Print the program help." },
+ { CMD_VERSION, "", "Print the program version." },
+ { "", "", "" },
+ { CMD_IMPORT, "<file-path>", "Import YAML configuration file." },
+ { CMD_EXPORT, "<file-path>", "Export YAML configuration file." },
+ { "", "", "" },
+ { CMD_BEGIN, "", "Begin a transaction." },
+ { CMD_COMMIT, "", "Commit a transaction." },
+ { CMD_ABORT, "", "Abort a transaction." },
+ { CMD_VALIDATE, "", "Validate a transaction changes." },
+ { CMD_DIFF, "", "Show configuration changes." },
+ { CMD_PERSIST, "", "Make running configuration persist during system reboots." },
+ { NULL }
+};
+
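+/* Derive a CLI command name from a schema node XPath by stripping the module
+ * prefixes and the common base path and replacing '/' with '.'.
+ * For example, the garbage-collector container referenced from sysrepo.h,
+ *   "/cznic-resolver-common:dns-resolver/cache/cznic-resolver-knot:garbage-collector",
+ * becomes "cache.garbage-collector". */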
+static char *create_cmd_name(const char *xpath)
+{
+	char *name = malloc(strlen(xpath) + 1);
+	if (!name) {
+		/* memory allocation failed */
+		return NULL;
+	}
+	strcpy(name, xpath);
+
+	/* remove module prefixes from the name, the order is important */
+	remove_substr(name, XPATH_BASE"/");
+	remove_substr(name, XPATH_BASE);
+	remove_substr(name, "/"YM_COMMON":");
+	remove_substr(name, YM_COMMON":");
+	remove_substr(name, YM_KRES":");
+	/* replace '/' with '.' */
+	replace_char(name, '/', '.');
+
+	return name;
+}
+
+static int create_cmd(struct lys_node *node)
+{
+ if (node->nodetype == LYS_GROUPING ||
+ node->nodetype == LYS_USES) {
+ return 0;
+ }
+
+	char *xpath = lys_data_path(node);
+	if (!xpath) {
+		return 0;
+	}
+	char *name = create_cmd_name(xpath);
+
+	if (!name || !strlen(name)) {
+		free(xpath);
+		free(name);
+		return 0;
+	}
+
+ cmd_desc_t *cmd = malloc(sizeof(cmd_desc_t));
+ cmd->name = name;
+ cmd->xpath = xpath;
+ cmd->fcn = &cmd_leaf;
+ cmd->flags = CMD_FNONE;
+
+ switch (node->nodetype) {
+ case LYS_CONTAINER:
+ cmd->fcn = &cmd_container;
+ break;
+ case LYS_LEAF:
+ cmd->fcn = &cmd_leaf;
+ break;
+ case LYS_LEAFLIST:
+ cmd->fcn = &cmd_leaflist;
+ break;
+ case LYS_LIST:
+ cmd->fcn = &cmd_list;
+ break;
+ case LYS_ACTION:
+ cmd->fcn = &cmd_rpc;
+ break;
+ case LYS_RPC:
+ cmd->fcn = &cmd_rpc;
+ break;
+ case LYS_NOTIF:
+ cmd->fcn = &cmd_notif;
+ break;
+ default:
+ cmd->fcn = &cmd_leaf;
+ break;
+ }
+
+ cmd_help_t *cmd_help = malloc(sizeof(cmd_help_t));
+ cmd_help->name = name;
+ cmd_help->desc = node->dsc;
+ cmd_help->params = "";
+
+ cmd_help_dynarray_add(&dyn_cmd_help_table, &cmd_help);
+ cmd_dynarray_add(&dyn_cmd_table, &cmd);
+
+ return CLI_EOK;
+}
+
+static void schema_iterator(struct lys_node *root)
+{
+ assert(root != NULL);
+
+ struct lys_node *node = NULL;
+
+ LY_TREE_FOR(root, node) {
+ assert(node != NULL);
+
+ create_cmd(node);
+
+		/* descend into children only for containers; skip lists, RPCs, actions and notifications */
+ if (node->child
+ && (node->nodetype != LYS_LIST)
+ && (node->nodetype != LYS_RPC)
+ && (node->nodetype != LYS_ACTION)
+ && (node->nodetype != LYS_NOTIF)
+ ) {
+ schema_iterator(node->child);
+ }
+ }
+}
+
+int create_cmd_table(sr_conn_ctx_t *sr_connection)
+{
+ assert(sr_connection != NULL);
+
+	struct lys_node *root = NULL;
+	const struct ly_ctx *ly_context = NULL;
+
+	/* get the libyang context of the sysrepo connection */
+	ly_context = sr_get_context(sr_connection);
+	if (!ly_context) {
+		printf("failed to get libyang context\n");
+		return CLI_ERR;
+	}
+
+	/* find the root schema node of the common module */
+	root = (struct lys_node *)ly_ctx_get_node(ly_context, NULL, XPATH_BASE, 0);
+	assert(root != NULL);
+
+	/* iterate through all schema nodes */
+ schema_iterator(root);
+
+ return CLI_EOK;
+}
+
+void destroy_cmd_table()
+{
+ cmd_dynarray_deep_free(&dyn_cmd_table);
+ cmd_help_dynarray_deep_free(&dyn_cmd_help_table);
+}
+
+int print_version(cmd_args_t *args)
+{
+ printf("%s (%s), version %s\n", PROGRAM_NAME, PROJECT_NAME, PACKAGE_VERSION);
+
+ return 0;
+}
+
+int print_commands(cmd_args_t *args)
+{
+ printf("\nCommands:\n");
+
+	/* Print all built-in commands */
+ for (const cmd_help_t *cmd = cmd_help_table; cmd->name != NULL; cmd++) {
+ printf(" %-15s %-15s %s\n", cmd->name, cmd->params, cmd->desc);
+ }
+ printf("\n");
+
+ /* Print all created commands */
+ dynarray_foreach(cmd_help, cmd_help_t *, i, dyn_cmd_help_table) {
+ cmd_help_t *cmd_help = *i;
+ printf(" %-40s %-10s %s\n", cmd_help->name, cmd_help->params, cmd_help->desc);
+ }
+
+ printf("\n"
+ "Note:\n"
+ "");
+
+ return 0;
+}
--- /dev/null
+#pragma once
+
+#include <stdbool.h>
+#include <string.h>
+#include <sysrepo.h>
+
+#include "contrib/dynarray.h"
+#include "process.h"
+
+
+struct cmd_desc;
+typedef struct cmd_desc cmd_desc_t;
+typedef struct cmd_help cmd_help_t;
+
+typedef enum {
+ CMD_FNONE = 0,
+ CMD_FINTER = 1 << 0, /* Interactive-only command. */
+	CMD_FSTATE = 1 << 1, /* State-only command. */
+} cmd_flag_t;
+
+typedef struct {
+ const cmd_desc_t *desc;
+ const sr_datastore_t ds;
+ int argc;
+ const char **argv;
+ int timeout;
+} cmd_args_t;
+
+struct cmd_desc {
+ const char *name;
+ int (*fcn)(cmd_args_t *);
+ const char *xpath;
+ cmd_flag_t flags;
+};
+
+struct cmd_help {
+ const char *name;
+ const char *params;
+ const char *desc;
+};
+
+dynarray_declare(cmd, cmd_desc_t *, DYNARRAY_VISIBILITY_STATIC, 0)
+ dynarray_define(cmd, cmd_desc_t *, DYNARRAY_VISIBILITY_STATIC)
+
+dynarray_declare(cmd_help, cmd_help_t *, DYNARRAY_VISIBILITY_STATIC, 0)
+ dynarray_define(cmd_help, cmd_help_t *, DYNARRAY_VISIBILITY_STATIC)
+
+int create_cmd_table(sr_conn_ctx_t *sr_connection);
+
+void destroy_cmd_table();
+
+int print_version(cmd_args_t *args);
+
+int print_commands(cmd_args_t *args);
+
+extern cmd_dynarray_t dyn_cmd_table;
+extern cmd_help_dynarray_t dyn_cmd_help_table;
+
+extern const cmd_desc_t cmd_table[];
+extern const cmd_help_t cmd_help_table[];
\ No newline at end of file
--- /dev/null
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <errno.h>
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#include "conf_file.h"
+
+
+static int step_read_file(FILE *file, char **mem)
+{
+ size_t mem_size, mem_used;
+
+ mem_size = 512;
+ mem_used = 0;
+ *mem = malloc(mem_size);
+
+ do {
+ if (mem_used == mem_size) {
+			/* grow the buffer (do not shrink it) */
+			mem_size <<= 1;
+			*mem = realloc(*mem, mem_size);
+ }
+
+ mem_used += fread(*mem + mem_used, 1, mem_size - mem_used, file);
+ } while (mem_used == mem_size);
+
+ if (ferror(file)) {
+ free(*mem);
+ printf("Error reading from file (%s)\n", strerror(errno));
+ return EXIT_FAILURE;
+ } else if (!feof(file)) {
+ free(*mem);
+ printf("Unknown file problem\n");
+ return EXIT_FAILURE;
+ }
+	/* NUL-terminate the buffer for lyd_parse_mem() */
+	(*mem)[mem_used] = '\0';
+
+	return EXIT_SUCCESS;
+}
+
+int step_load_data(sr_session_ctx_t *sr_session, const char *file_path, int flags, struct lyd_node **data)
+{
+ struct ly_ctx *ly_ctx;
+ char *ptr;
+
+ ly_ctx = (struct ly_ctx *)sr_get_context(sr_session_get_connection(sr_session));
+
+ /* parse import data */
+ if (file_path) {
+ *data = lyd_parse_path(ly_ctx, file_path, LYD_JSON, flags, NULL);
+ } else {
+ /* need to load the data into memory first */
+ if (step_read_file(stdin, &ptr)) {
+ return EXIT_FAILURE;
+ }
+ *data = lyd_parse_mem(ly_ctx, ptr, LYD_JSON, flags);
+ free(ptr);
+ }
+ if (ly_errno) {
+ printf("Data parsing failed\n");
+ return EXIT_FAILURE;
+ }
+
+ return EXIT_SUCCESS;
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+
+int step_load_data(sr_session_ctx_t *sess, const char *file_path, int flags, struct lyd_node **data);
\ No newline at end of file
--- /dev/null
+/* Copyright (C) 2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/*!
+ * \brief Locale-independent ctype functions.
+ */
+
+#pragma once
+
+#include <ctype.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+enum {
+ CT_DIGIT = 1 << 0,
+ CT_UPPER = 1 << 1,
+ CT_LOWER = 1 << 2,
+ CT_XDIGT = 1 << 3,
+ CT_PUNCT = 1 << 4,
+ CT_PRINT = 1 << 5,
+ CT_SPACE = 1 << 6,
+};
+
+static const uint8_t char_mask[256] = {
+ // 0 - 8
+ ['\t'] = CT_SPACE,
+ ['\n'] = CT_SPACE,
+ ['\v'] = CT_SPACE,
+ ['\f'] = CT_SPACE,
+ ['\r'] = CT_SPACE,
+ // 14 - 31
+ [' '] = CT_PRINT | CT_SPACE,
+
+ ['!'] = CT_PRINT | CT_PUNCT,
+ ['"'] = CT_PRINT | CT_PUNCT,
+ ['#'] = CT_PRINT | CT_PUNCT,
+ ['$'] = CT_PRINT | CT_PUNCT,
+ ['%'] = CT_PRINT | CT_PUNCT,
+ ['&'] = CT_PRINT | CT_PUNCT,
+ ['\''] = CT_PRINT | CT_PUNCT,
+ ['('] = CT_PRINT | CT_PUNCT,
+ [')'] = CT_PRINT | CT_PUNCT,
+ ['*'] = CT_PRINT | CT_PUNCT,
+ ['+'] = CT_PRINT | CT_PUNCT,
+ [','] = CT_PRINT | CT_PUNCT,
+ ['-'] = CT_PRINT | CT_PUNCT,
+ ['.'] = CT_PRINT | CT_PUNCT,
+ ['/'] = CT_PRINT | CT_PUNCT,
+
+ ['0'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['1'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['2'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['3'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['4'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['5'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['6'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['7'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['8'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+ ['9'] = CT_PRINT | CT_DIGIT | CT_XDIGT,
+
+ [':'] = CT_PRINT | CT_PUNCT,
+ [';'] = CT_PRINT | CT_PUNCT,
+ ['<'] = CT_PRINT | CT_PUNCT,
+ ['='] = CT_PRINT | CT_PUNCT,
+ ['>'] = CT_PRINT | CT_PUNCT,
+ ['?'] = CT_PRINT | CT_PUNCT,
+ ['@'] = CT_PRINT | CT_PUNCT,
+
+ ['A'] = CT_PRINT | CT_UPPER | CT_XDIGT,
+ ['B'] = CT_PRINT | CT_UPPER | CT_XDIGT,
+ ['C'] = CT_PRINT | CT_UPPER | CT_XDIGT,
+ ['D'] = CT_PRINT | CT_UPPER | CT_XDIGT,
+ ['E'] = CT_PRINT | CT_UPPER | CT_XDIGT,
+ ['F'] = CT_PRINT | CT_UPPER | CT_XDIGT,
+ ['G'] = CT_PRINT | CT_UPPER,
+ ['H'] = CT_PRINT | CT_UPPER,
+ ['I'] = CT_PRINT | CT_UPPER,
+ ['J'] = CT_PRINT | CT_UPPER,
+ ['K'] = CT_PRINT | CT_UPPER,
+ ['L'] = CT_PRINT | CT_UPPER,
+ ['M'] = CT_PRINT | CT_UPPER,
+ ['N'] = CT_PRINT | CT_UPPER,
+ ['O'] = CT_PRINT | CT_UPPER,
+ ['P'] = CT_PRINT | CT_UPPER,
+ ['Q'] = CT_PRINT | CT_UPPER,
+ ['R'] = CT_PRINT | CT_UPPER,
+ ['S'] = CT_PRINT | CT_UPPER,
+ ['T'] = CT_PRINT | CT_UPPER,
+ ['U'] = CT_PRINT | CT_UPPER,
+ ['V'] = CT_PRINT | CT_UPPER,
+ ['W'] = CT_PRINT | CT_UPPER,
+ ['X'] = CT_PRINT | CT_UPPER,
+ ['Y'] = CT_PRINT | CT_UPPER,
+ ['Z'] = CT_PRINT | CT_UPPER,
+
+ ['['] = CT_PRINT | CT_PUNCT,
+ ['\\'] = CT_PRINT | CT_PUNCT,
+ [']'] = CT_PRINT | CT_PUNCT,
+ ['^'] = CT_PRINT | CT_PUNCT,
+ ['_'] = CT_PRINT | CT_PUNCT,
+ ['`'] = CT_PRINT | CT_PUNCT,
+
+ ['a'] = CT_PRINT | CT_LOWER | CT_XDIGT,
+ ['b'] = CT_PRINT | CT_LOWER | CT_XDIGT,
+ ['c'] = CT_PRINT | CT_LOWER | CT_XDIGT,
+ ['d'] = CT_PRINT | CT_LOWER | CT_XDIGT,
+ ['e'] = CT_PRINT | CT_LOWER | CT_XDIGT,
+ ['f'] = CT_PRINT | CT_LOWER | CT_XDIGT,
+ ['g'] = CT_PRINT | CT_LOWER,
+ ['h'] = CT_PRINT | CT_LOWER,
+ ['i'] = CT_PRINT | CT_LOWER,
+ ['j'] = CT_PRINT | CT_LOWER,
+ ['k'] = CT_PRINT | CT_LOWER,
+ ['l'] = CT_PRINT | CT_LOWER,
+ ['m'] = CT_PRINT | CT_LOWER,
+ ['n'] = CT_PRINT | CT_LOWER,
+ ['o'] = CT_PRINT | CT_LOWER,
+ ['p'] = CT_PRINT | CT_LOWER,
+ ['q'] = CT_PRINT | CT_LOWER,
+ ['r'] = CT_PRINT | CT_LOWER,
+ ['s'] = CT_PRINT | CT_LOWER,
+ ['t'] = CT_PRINT | CT_LOWER,
+ ['u'] = CT_PRINT | CT_LOWER,
+ ['v'] = CT_PRINT | CT_LOWER,
+ ['w'] = CT_PRINT | CT_LOWER,
+ ['x'] = CT_PRINT | CT_LOWER,
+ ['y'] = CT_PRINT | CT_LOWER,
+ ['z'] = CT_PRINT | CT_LOWER,
+
+ ['{'] = CT_PRINT | CT_PUNCT,
+ ['|'] = CT_PRINT | CT_PUNCT,
+ ['}'] = CT_PRINT | CT_PUNCT,
+ ['~'] = CT_PRINT | CT_PUNCT,
+ // 127 - 255
+};
+
+static inline bool is_alnum(uint8_t c)
+{
+ return char_mask[c] & (CT_DIGIT | CT_UPPER | CT_LOWER);
+}
+
+static inline bool is_alpha(uint8_t c)
+{
+ return char_mask[c] & (CT_UPPER | CT_LOWER);
+}
+
+static inline bool is_digit(uint8_t c)
+{
+ return char_mask[c] & CT_DIGIT;
+}
+
+static inline bool is_xdigit(uint8_t c)
+{
+ return char_mask[c] & CT_XDIGT;
+}
+
+static inline bool is_lower(uint8_t c)
+{
+ return char_mask[c] & CT_LOWER;
+}
+
+static inline bool is_upper(uint8_t c)
+{
+ return char_mask[c] & CT_UPPER;
+}
+
+static inline bool is_print(uint8_t c)
+{
+ return char_mask[c] & CT_PRINT;
+}
+
+static inline bool is_punct(uint8_t c)
+{
+ return char_mask[c] & CT_PUNCT;
+}
+
+static inline bool is_space(uint8_t c)
+{
+ return char_mask[c] & CT_SPACE;
+}
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <string.h>
+
+#include "lookup.h"
+#include "mempattern.h"
+#include "contrib/ucw/mempool.h"
+#include "libknot/error.h"
+
+int lookup_init(lookup_t *lookup)
+{
+ if (lookup == NULL) {
+ return KNOT_EINVAL;
+ }
+ memset(lookup, 0, sizeof(*lookup));
+
+ mm_ctx_mempool(&lookup->mm, MM_DEFAULT_BLKSIZE);
+ lookup->trie = trie_create(&lookup->mm);
+ if (lookup->trie == NULL) {
+ mp_delete(lookup->mm.ctx);
+ return KNOT_ENOMEM;
+ }
+
+ return KNOT_EOK;
+}
+
+static void reset_output(lookup_t *lookup)
+{
+ if (lookup == NULL) {
+ return;
+ }
+
+ mm_free(&lookup->mm, lookup->found.key);
+ lookup->found.key = NULL;
+ lookup->found.data = NULL;
+
+ lookup->iter.count = 0;
+
+ mm_free(&lookup->mm, lookup->iter.first_key);
+ lookup->iter.first_key = NULL;
+
+ trie_it_free(lookup->iter.it);
+ lookup->iter.it = NULL;
+}
+
+void lookup_deinit(lookup_t *lookup)
+{
+ if (lookup == NULL) {
+ return;
+ }
+
+ reset_output(lookup);
+
+ trie_free(lookup->trie);
+ mp_delete(lookup->mm.ctx);
+}
+
+int lookup_insert(lookup_t *lookup, const char *str, void *data)
+{
+ if (lookup == NULL || str == NULL) {
+ return KNOT_EINVAL;
+ }
+
+ size_t str_len = strlen(str);
+ if (str_len == 0) {
+ return KNOT_EINVAL;
+ }
+
+ trie_val_t *val = trie_get_ins(lookup->trie, (const trie_key_t *)str, str_len);
+ if (val == NULL) {
+ return KNOT_ENOMEM;
+ }
+ *val = data;
+
+ return KNOT_EOK;
+}
+
+static int set_key(lookup_t *lookup, char **dst, const char *key, size_t key_len)
+{
+ if (*dst != NULL) {
+ mm_free(&lookup->mm, *dst);
+ }
+ *dst = mm_alloc(&lookup->mm, key_len + 1);
+ if (*dst == NULL) {
+ return KNOT_ENOMEM;
+ }
+ memcpy(*dst, key, key_len);
+ (*dst)[key_len] = '\0';
+
+ return KNOT_EOK;
+}
+
+int lookup_search(lookup_t *lookup, const char *str, size_t str_len)
+{
+ if (lookup == NULL) {
+ return KNOT_EINVAL;
+ }
+
+ // Change NULL string to the empty one.
+ if (str == NULL) {
+ str = "";
+ }
+
+ reset_output(lookup);
+
+ size_t new_len = 0;
+ trie_it_t *it = trie_it_begin(lookup->trie);
+ for (; !trie_it_finished(it); trie_it_next(it)) {
+ size_t len;
+ const char *key = (const char *)trie_it_key(it, &len);
+
+ // Compare with a shorter key.
+ if (len < str_len) {
+ int ret = memcmp(str, key, len);
+ if (ret >= 0) {
+ continue;
+ } else {
+ break;
+ }
+ }
+
+ // Compare with an equal length or longer key.
+ int ret = memcmp(str, key, str_len);
+ if (ret == 0) {
+ lookup->iter.count++;
+
+ // First candidate.
+ if (lookup->iter.count == 1) {
+ ret = set_key(lookup, &lookup->found.key, key, len);
+ if (ret != KNOT_EOK) {
+ break;
+ }
+ lookup->found.data = *trie_it_val(it);
+ new_len = len;
+ // Another candidate.
+ } else if (new_len > str_len) {
+ if (new_len > len) {
+ new_len = len;
+ }
+ while (memcmp(lookup->found.key, key, new_len) != 0) {
+ new_len--;
+ }
+ }
+ // Stop if greater than the key, and also than all the following keys.
+ } else if (ret < 0) {
+ break;
+ }
+ }
+ trie_it_free(it);
+
+ switch (lookup->iter.count) {
+ case 0:
+ return KNOT_ENOENT;
+ case 1:
+ return KNOT_EOK;
+ default:
+ // Store full name of the first candidate.
+ if (set_key(lookup, &lookup->iter.first_key, lookup->found.key,
+ strlen(lookup->found.key)) != KNOT_EOK) {
+ return KNOT_ENOMEM;
+ }
+ lookup->found.key[new_len] = '\0';
+ lookup->found.data = NULL;
+
+ return KNOT_EFEWDATA;
+ }
+}
+
+void lookup_list(lookup_t *lookup)
+{
+ if (lookup == NULL || lookup->iter.first_key == NULL) {
+ return;
+ }
+
+ if (lookup->iter.it != NULL) {
+ if (trie_it_finished(lookup->iter.it)) {
+ trie_it_free(lookup->iter.it);
+ lookup->iter.it = NULL;
+ return;
+ }
+
+ trie_it_next(lookup->iter.it);
+
+ size_t len;
+ const char *key = (const char *)trie_it_key(lookup->iter.it, &len);
+
+ int ret = set_key(lookup, &lookup->found.key, key, len);
+ if (ret == KNOT_EOK) {
+ lookup->found.data = *trie_it_val(lookup->iter.it);
+ }
+ return;
+ }
+
+ lookup->iter.it = trie_it_begin(lookup->trie);
+ while (!trie_it_finished(lookup->iter.it)) {
+ size_t len;
+ const char *key = (const char *)trie_it_key(lookup->iter.it, &len);
+
+ if (strncmp(key, lookup->iter.first_key, len) == 0) {
+ int ret = set_key(lookup, &lookup->found.key, key, len);
+ if (ret == KNOT_EOK) {
+ lookup->found.data = *trie_it_val(lookup->iter.it);
+ }
+ break;
+ }
+ trie_it_next(lookup->iter.it);
+ }
+}
+
+static void print_options(lookup_t *lookup, EditLine *el)
+{
+ // Get terminal lines.
+ unsigned lines = 0;
+ if (el_get(el, EL_GETTC, "li", &lines) != 0 || lines < 3) {
+ return;
+ }
+
+ for (size_t i = 1; i <= lookup->iter.count; i++) {
+ lookup_list(lookup);
+ printf("\n%s", lookup->found.key);
+
+ if (i > 1 && i % (lines - 1) == 0 && i < lookup->iter.count) {
+ printf("\n Display next from %zu possibilities? (y or n)",
+ lookup->iter.count);
+ char next;
+ el_getc(el, &next);
+ if (next != 'y') {
+ break;
+ }
+ }
+ }
+
+ printf("\n");
+ fflush(stdout);
+}
+
+void lookup_complete(lookup_t *lookup, const char *str, size_t str_len,
+ EditLine *el, bool add_space)
+{
+ if (lookup == NULL || el == NULL) {
+ return;
+ }
+
+ // Try to complete the command name.
+ int ret = lookup_search(lookup, str, str_len);
+ switch (ret) {
+ case KNOT_EOK:
+ el_deletestr(el, str_len);
+ el_insertstr(el, lookup->found.key);
+ if (add_space) {
+ el_insertstr(el, " ");
+ }
+ break;
+ case KNOT_EFEWDATA:
+ if (strlen(lookup->found.key) > str_len) {
+ el_deletestr(el, str_len);
+ el_insertstr(el, lookup->found.key);
+ } else {
+ print_options(lookup, el);
+ }
+ break;
+ default:
+ break;
+ }
+}
\ No newline at end of file
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <histedit.h>
+
+#include "libknot/mm_ctx.h"
+#include "qp-trie/trie.h"
+
+/*! Lookup context. */
+typedef struct {
+ /*! Memory pool context. */
+ knot_mm_t mm;
+ /*! Main trie storage. */
+ trie_t *trie;
+
+ /*! Current (iteration) data context. */
+ struct {
+ /*! Stored key. */
+ char *key;
+ /*! Corresponding key data. */
+ void *data;
+ } found;
+
+ /*! Iteration context. */
+ struct {
+ /*! Total number of possibilies. */
+ size_t count;
+ /*! The first possibility. */
+ char *first_key;
+ /*! Hat-trie iterator. */
+ trie_it_t *it;
+ } iter;
+} lookup_t;
+
+/*!
+ * Initializes the lookup context.
+ *
+ * \param[in] lookup Lookup context.
+ *
+ * \return Error code, KNOT_EOK if successful.
+ */
+int lookup_init(lookup_t *lookup);
+
+/*!
+ * Deinitializes the lookup context.
+ *
+ * \param[in] lookup Lookup context.
+ */
+void lookup_deinit(lookup_t *lookup);
+
+/*!
+ * Inserts given key and data into the lookup.
+ *
+ * \param[in] lookup Lookup context.
+ * \param[in] str Textual key.
+ * \param[in] data Key textual data.
+ *
+ * \return Error code, KNOT_EOK if successful.
+ */
+int lookup_insert(lookup_t *lookup, const char *str, void *data);
+
+/*!
+ * Searches the lookup container for the given key.
+ *
+ * \note If one candidate, lookup.found contains the key/data,
+ * if more candidates, lookup.found contains the common key prefix and
+ * lookup.iter.first_key is the first candidate key.
+ *
+ * \param[in] lookup Lookup context.
+ * \param[in] str Textual key.
+ * \param[in] str_len Textual key length.
+ *
+ * \return Error code, KNOT_EOK if 1 candidate, KNOT_ENOENT if no candidate,
+ * and KNOT_EFEWDATA if more candidates are possible.
+ */
+int lookup_search(lookup_t *lookup, const char *str, size_t str_len);
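+/*
+ * Illustrative example (not from the upstream sources): after inserting the
+ * keys "begin" and "validate", searching for "va" returns KNOT_EOK with
+ * found.key == "validate", searching for "x" returns KNOT_ENOENT, and
+ * searching for "" returns KNOT_EFEWDATA with found.key trimmed to the
+ * common prefix of the candidates and iter.first_key == "begin".
+ */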
+
+/*!
+ * Moves the lookup iterator to the next key candidate.
+ *
+ * \note lookup.found is updated.
+ *
+ * \param[in] lookup Lookup context.
+ */
+void lookup_list(lookup_t *lookup);
+
+/*!
+ * Completes the string based on the lookup content or prints all candidates.
+ *
+ * \param[in] lookup Lookup context.
+ * \param[in] str Textual key.
+ * \param[in] str_len Textual key length.
+ * \param[in] el Editline context.
+ * \param[in] add_space Add one space after completed string flag.
+ */
+void lookup_complete(lookup_t *lookup, const char *str, size_t str_len,
+ EditLine *el, bool add_space);
\ No newline at end of file
--- /dev/null
+/* Copyright (C) 2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <stdlib.h>
+
+#include "mempattern.h"
+#include "string.h"
+#include "ucw/mempool.h"
+
+static void mm_nofree(void *p)
+{
+ /* nop */
+}
+
+static void *mm_malloc(void *ctx, size_t n)
+{
+ (void)ctx;
+ return malloc(n);
+}
+
+void *mm_alloc(knot_mm_t *mm, size_t size)
+{
+ if (mm) {
+ return mm->alloc(mm->ctx, size);
+ } else {
+ return malloc(size);
+ }
+}
+
+void *mm_calloc(knot_mm_t *mm, size_t nmemb, size_t size)
+{
+ if (nmemb == 0 || size == 0) {
+ return NULL;
+ }
+ if (mm) {
+ size_t total_size = nmemb * size;
+ if (total_size / nmemb != size) { // Overflow check
+ return NULL;
+ }
+ void *mem = mm_alloc(mm, total_size);
+ if (mem == NULL) {
+ return NULL;
+ }
+ return memzero(mem, total_size);
+ } else {
+ return calloc(nmemb, size);
+ }
+}
+/*
+void *mm_realloc(knot_mm_t *mm, void *what, size_t size, size_t prev_size)
+{
+ if (mm) {
+ void *p = mm->alloc(mm->ctx, size);
+ if (p == NULL) {
+ return NULL;
+ } else {
+ if (what) {
+ memcpy(p, what,
+ prev_size < size ? prev_size : size);
+ }
+ mm_free(mm, what);
+ return p;
+ }
+ } else {
+ return realloc(what, size);
+ }
+}
+*/
+char *mm_strdup(knot_mm_t *mm, const char *s)
+{
+ if (s == NULL) {
+ return NULL;
+ }
+ if (mm) {
+ size_t len = strlen(s) + 1;
+ void *mem = mm_alloc(mm, len);
+ if (mem == NULL) {
+ return NULL;
+ }
+ return memcpy(mem, s, len);
+ } else {
+ return strdup(s);
+ }
+}
+
+void mm_free(knot_mm_t *mm, void *what)
+{
+ if (mm) {
+ if (mm->free) {
+ mm->free(what);
+ }
+ } else {
+ free(what);
+ }
+}
+
+void mm_ctx_init(knot_mm_t *mm)
+{
+ mm->ctx = NULL;
+ mm->alloc = mm_malloc;
+ mm->free = free;
+}
+
+void mm_ctx_mempool(knot_mm_t *mm, size_t chunk_size)
+{
+ mm->ctx = mp_new(chunk_size);
+ mm->alloc = (knot_mm_alloc_t)mp_alloc;
+ mm->free = mm_nofree;
+}
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/*!
+ * \brief Memory allocation related functions.
+ */
+
+#pragma once
+
+#include "libknot/mm_ctx.h"
+
+/*! \brief Default memory block size. */
+#define MM_DEFAULT_BLKSIZE 4096
+
+/*! \brief Allocs using 'mm' if any, uses system malloc() otherwise. */
+void *mm_alloc(knot_mm_t *mm, size_t size);
+
+/*! \brief Callocs using 'mm' if any, uses system calloc() otherwise. */
+void *mm_calloc(knot_mm_t *mm, size_t nmemb, size_t size);
+
+/*! \brief Reallocs using 'mm' if any, uses system realloc() otherwise. */
+//void *mm_realloc(knot_mm_t *mm, void *what, size_t size, size_t prev_size);
+
+/*! \brief Strdups using 'mm' if any, uses system strdup() otherwise. */
+char *mm_strdup(knot_mm_t *mm, const char *s);
+
+/*! \brief Free using 'mm' if any, uses system free() otherwise. */
+void mm_free(knot_mm_t *mm, void *what);
+
+/*! \brief Initialize default memory allocation context. */
+void mm_ctx_init(knot_mm_t *mm);
+
+/*! \brief Memory pool context. */
+void mm_ctx_mempool(knot_mm_t *mm, size_t chunk_size);
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ Copyright (C) 2018 Tony Finch <dot@dotat.at>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ The code originated from https://github.com/fanf2/qp/blob/master/qp.c
+ at revision 5f6d93753.
+ */
+
+#include <assert.h>
+#include <limits.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "contrib/qp-trie/trie.h"
+#include "contrib/macros.h"
+#include "contrib/mempattern.h"
+#include "libknot/errcode.h"
+
+typedef unsigned int uint;
+typedef uint64_t index_t; /*!< nibble index into a key */
+typedef uint64_t word; /*!< A type-punned word */
+typedef uint bitmap_t; /*!< Bit-maps, using the range of 1<<0 to 1<<16 (inclusive). */
+
+typedef char static_assert_pointer_fits_in_word
+ [sizeof(word) >= sizeof(uintptr_t) ? 1 : -1];
+
+#define KEYLENBITS 31
+
+/*! \brief trie keys have lengths
+ *
+ * 32 bits are enough for key lengths; probably even 16 bits would be.
+ * However, a 32 bit length means the alignment will be a multiple of
+ * 4, allowing us to stash the COW and BRANCH flags in the bottom bits
+ * of a pointer to a key.
+ *
+ * We need to steal a couple of bits from the length to keep the COW
+ * state of key allocations.
+ */
+typedef struct {
+ uint32_t cow:1, len:KEYLENBITS;
+ trie_key_t chars[];
+} tkey_t;
+
+/*! \brief A trie node is a pair of words.
+ *
+ * Each word is type-punned, depending on whether this is a branch
+ * node or a leaf node. We'll define some accessor functions to wrap
+ * this up into something reasonably safe.
+ *
+ * We aren't using a union to avoid problems with strict aliasing, and
+ * we aren't using bitfields because we want to control exactly which
+ * bits in the word are used by each field (in particular the flags).
+ *
+ * Branch nodes are never allocated individually: they are always part
+ * of either the root node or the twigs array of their parent branch.
+ *
+ * In a branch:
+ *
+ * `i` contains flags, bitmap, and index, explained in more detail below.
+ *
+ * `p` is a pointer to the "twigs", an array of child nodes.
+ *
+ * In a leaf:
+ *
+ * `i` is cast from a pointer to a tkey_t, with flags in the bottom bits.
+ *
+ * `p` is a trie_val_t.
+ */
+typedef struct node {
+ word i;
+ void *p;
+} node_t;
+
+struct trie {
+ node_t root; // undefined when weight == 0, see empty_root()
+ size_t weight;
+ knot_mm_t mm;
+};
+
+/*! \brief size (in bits) of nibble (half-byte) indexes into keys
+ *
+ * The bottom bit is clear for the upper nibble, and set for the lower
+ * nibble, big-endian style, since the tree has to be in lexicographic
+ * order. The index increases from one branch node to the next as you
+ * go deeper into the trie. All the keys below a branch are identical
+ * up to the nibble identified by the branch.
+ *
+ * (see also tkey_t.len above)
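+ *
+ * Worked example (added for clarity): for the two-byte key "ab" (0x61 0x62),
+ * nibble index 0 selects 0x6, index 1 selects 0x1, index 2 selects 0x6,
+ * index 3 selects 0x2, and any index beyond the key maps to the special
+ * "no byte" value (see keybit() below).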
+ */
+#define TWIDTH_INDEX 33
+
+/*! \brief exclusive limit on indexes */
+#define TMAX_INDEX (BIG1 << TWIDTH_INDEX)
+
+/*! \brief size (in bits) of branch bitmap
+ *
+ * The bitmap indicates which subtries are present. The present child
+ * nodes are stored in the twigs array (with no holes between them).
+ *
+ * To simplify storing keys that are prefixes of each other, the
+ * end-of-string position is treated as an extra nibble value, ordered
+ * before all others. So there are 16 possible real nibble values,
+ * plus one value for nibbles past the end of the key.
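+ *
+ * Concretely (added for clarity, see keybit() and BMP_NOBYTE below): the
+ * lowest bitmap position stands for "no byte" (the key has ended) and the
+ * following 16 positions stand for nibble values 0x0..0xf, giving the 17
+ * bits declared here.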
+ */
+#define TWIDTH_BMP 17
+
+/*
+ * We're constructing the layout of the branch `i` field in a careful
+ * way to avoid mistakes, getting the compiler to calculate values
+ * rather than typing them in by hand.
+ */
+enum {
+ TSHIFT_BRANCH = 0,
+ TSHIFT_COW,
+ TSHIFT_BMP,
+ TOP_BMP = TSHIFT_BMP + TWIDTH_BMP,
+ TSHIFT_INDEX = TOP_BMP,
+ TOP_INDEX = TSHIFT_INDEX + TWIDTH_INDEX,
+};
+
+typedef char static_assert_fields_fit_in_word
+ [TOP_INDEX <= sizeof(word) * CHAR_BIT ? 1 : -1];
+
+typedef char static_assert_bmp_fits
+ [TOP_BMP <= sizeof(bitmap_t) * CHAR_BIT ? 1 : -1];
+
+#define BIG1 ((word)1)
+#define TMASK(width, shift) (((BIG1 << (width)) - BIG1) << (shift))
+
+/*! \brief is this node a branch or a leaf? */
+#define TFLAG_BRANCH (BIG1 << TSHIFT_BRANCH)
+
+/*! \brief copy-on-write flag, used in both leaves and branches */
+#define TFLAG_COW (BIG1 << TSHIFT_COW)
+
+/*! \brief for extracting pointer to key */
+#define TMASK_LEAF (~(word)(TFLAG_BRANCH | TFLAG_COW))
+
+/*! \brief mask for extracting nibble index */
+#define TMASK_INDEX TMASK(TWIDTH_INDEX, TSHIFT_INDEX)
+
+/*! \brief mask for extracting bitmap */
+#define TMASK_BMP TMASK(TWIDTH_BMP, TSHIFT_BMP)
+
+/*! \brief bitmap entry for NOBYTE */
+#define BMP_NOBYTE (BIG1 << TSHIFT_BMP)
+
+/*! \brief Initialize a new leaf, copying the key, and returning failure code. */
+static int mkleaf(node_t *leaf, const trie_key_t *key, uint32_t len, knot_mm_t *mm)
+{
+ if (unlikely((word)len > (BIG1 << KEYLENBITS)))
+ return KNOT_ENOMEM;
+ tkey_t *lkey = mm_alloc(mm, sizeof(tkey_t) + len);
+ if (unlikely(!lkey))
+ return KNOT_ENOMEM;
+ lkey->cow = 0;
+ lkey->len = len;
+ memcpy(lkey->chars, key, len);
+ word i = (uintptr_t)lkey;
+ assert((i & TFLAG_BRANCH) == 0);
+ *leaf = (node_t){ .i = i, .p = NULL };
+ return KNOT_EOK;
+}
+
+/*! \brief construct a branch node */
+static node_t mkbranch(index_t index, bitmap_t bmp, node_t *twigs)
+{
+ word i = TFLAG_BRANCH | bmp
+ | (index << TSHIFT_INDEX);
+ assert(index < TMAX_INDEX);
+ assert((bmp & ~TMASK_BMP) == 0);
+ return (node_t){ .i = i, .p = twigs };
+}
+
+/*! \brief Make an empty root node. */
+static node_t empty_root(void)
+{
+ return mkbranch(TMAX_INDEX-1, 0, NULL);
+}
+
+/*! \brief Propagate error codes. */
+#define ERR_RETURN(x) \
+ do { \
+ int err_code_ = x; \
+ if (unlikely(err_code_ != KNOT_EOK)) \
+ return err_code_; \
+ } while (false)
+
+
+/*! \brief Test flags to determine type of this node. */
+static bool isbranch(const node_t *t)
+{
+ return t->i & TFLAG_BRANCH;
+}
+
+static tkey_t *tkey(const node_t *t)
+{
+ assert(!isbranch(t));
+ return (tkey_t *)(uintptr_t)(t->i & TMASK_LEAF);
+}
+
+static trie_val_t *tvalp(node_t *t)
+{
+ assert(!isbranch(t));
+ return &t->p;
+}
+
+/*! \brief Given a branch node, return the index of the corresponding nibble in the key. */
+static index_t branch_index(const node_t *t)
+{
+ assert(isbranch(t));
+ return (t->i & TMASK_INDEX) >> TSHIFT_INDEX;
+}
+
+static bitmap_t branch_bmp(const node_t *t)
+{
+ assert(isbranch(t));
+ return (t->i & TMASK_BMP);
+}
+
+/*!
+ * \brief Count the number of set bits.
+ *
+ * \TODO This implementation may be relatively slow on some HW.
+ */
+static uint branch_weight(const node_t *t)
+{
+ assert(isbranch(t));
+ uint n = __builtin_popcount(t->i & TMASK_BMP);
+ assert(n > 1 && n <= TWIDTH_BMP);
+ return n;
+}
+
+/*! \brief Compute offset of an existing child in a branch node. */
+static uint twigoff(const node_t *t, bitmap_t bit)
+{
+ assert(isbranch(t));
+ assert(__builtin_popcount(bit) == 1);
+ return __builtin_popcount(t->i & TMASK_BMP & (bit - 1));
+}
+
+/*! \brief Extract a nibble from a key and turn it into a bitmask. */
+static bitmap_t keybit(index_t ni, const trie_key_t *key, uint32_t len)
+{
+ index_t bytei = ni >> 1;
+
+ if (bytei >= len)
+ return BMP_NOBYTE;
+
+ uint8_t ki = (uint8_t)key[bytei];
+ uint nibble = (ni & 1) ? (ki & 0xf) : (ki >> 4);
+
+ // skip one for NOBYTE nibbles after the end of the key
+ return BIG1 << (nibble + 1 + TSHIFT_BMP);
+}
+
+/*! \brief Extract a nibble from a key and turn it into a bitmask. */
+static bitmap_t twigbit(const node_t *t, const trie_key_t *key, uint32_t len)
+{
+ assert(isbranch(t));
+ return keybit(branch_index(t), key, len);
+}
+
+/*! \brief Test if a branch node has a child indicated by a bitmask. */
+static bool hastwig(const node_t *t, bitmap_t bit)
+{
+ assert(isbranch(t));
+ assert((bit & ~TMASK_BMP) == 0);
+ assert(__builtin_popcount(bit) == 1);
+ return t->i & bit;
+}
+
+/*! \brief Get pointer to packed array of child nodes. */
+static node_t* twigs(node_t *t)
+{
+ assert(isbranch(t));
+ return t->p;
+}
+
+/*! \brief Get pointer to a particular child of a branch node. */
+static node_t* twig(node_t *t, uint i)
+{
+ assert(i < branch_weight(t));
+ return twigs(t) + i;
+}
+
+/*! \brief Get twig number of a child node TODO: better description. */
+static uint twig_number(node_t *child, node_t *parent)
+{
+ // twig array index using pointer arithmetic
+ ptrdiff_t num = child - twigs(parent);
+ assert(num >= 0 && num < branch_weight(parent));
+ return (uint)num;
+}
+
+/*! \brief Simple string comparator. */
+static int key_cmp(const trie_key_t *k1, uint32_t k1_len,
+ const trie_key_t *k2, uint32_t k2_len)
+{
+ int ret = memcmp(k1, k2, MIN(k1_len, k2_len));
+ if (ret != 0) {
+ return ret;
+ }
+
+ /* Key string is equal, compare lengths. */
+ if (k1_len == k2_len) {
+ return 0;
+ } else if (k1_len < k2_len) {
+ return -1;
+ } else {
+ return 1;
+ }
+}
+
+trie_t* trie_create(knot_mm_t *mm)
+{
+ trie_t *trie = mm_alloc(mm, sizeof(trie_t));
+ if (trie != NULL) {
+ trie->root = empty_root();
+ trie->weight = 0;
+ if (mm != NULL)
+ trie->mm = *mm;
+ else
+ mm_ctx_init(&trie->mm);
+ }
+ return trie;
+}
+
+/*! \brief Free anything under the trie node, except for the passed pointer itself. */
+static void clear_trie(node_t *trie, knot_mm_t *mm)
+{
+ if (!isbranch(trie)) {
+ mm_free(mm, tkey(trie));
+ } else {
+ uint n = branch_weight(trie);
+ for (uint i = 0; i < n; ++i)
+ clear_trie(twig(trie, i), mm);
+ mm_free(mm, twigs(trie));
+ }
+}
+
+void trie_free(trie_t *tbl)
+{
+ if (tbl == NULL)
+ return;
+ if (tbl->weight)
+ clear_trie(&tbl->root, &tbl->mm);
+ mm_free(&tbl->mm, tbl);
+}
+
+void trie_clear(trie_t *tbl)
+{
+ assert(tbl);
+ if (!tbl->weight)
+ return;
+ clear_trie(&tbl->root, &tbl->mm);
+ tbl->root = empty_root();
+ tbl->weight = 0;
+}
+
+static bool dup_trie(node_t *copy, const node_t *orig, trie_dup_cb dup_cb, knot_mm_t *mm)
+{
+ if (isbranch(orig)) {
+ uint n = branch_weight(orig);
+ node_t *cotw = mm_alloc(mm, n * sizeof(*cotw));
+		if (cotw == NULL) {
+			return false;
+		}
+ const node_t *ortw = twigs((node_t *)orig);
+ for (uint i = 0; i < n; ++i) {
+ if (!dup_trie(cotw + i, ortw + i, dup_cb, mm)) {
+ while (i-- > 0) {
+ clear_trie(cotw + i, mm);
+ }
+ mm_free(mm, cotw);
+ return false;
+ }
+ }
+ *copy = mkbranch(branch_index(orig), branch_bmp(orig), cotw);
+ } else {
+ tkey_t *key = tkey(orig);
+ if (mkleaf(copy, key->chars, key->len, mm) != KNOT_EOK) {
+ return false;
+ }
+ if ((copy->p = dup_cb(orig->p, mm)) == NULL) {
+ mm_free(mm, tkey(copy));
+ return false;
+ }
+ }
+ return true;
+}
+
+trie_t* trie_dup(const trie_t *orig, trie_dup_cb dup_cb, knot_mm_t *mm)
+{
+ if (orig == NULL) {
+ return NULL;
+ }
+ trie_t *copy = mm_alloc(mm, sizeof(*copy));
+ if (copy == NULL) {
+ return NULL;
+ }
+ copy->weight = orig->weight;
+ if (mm != NULL) {
+ copy->mm = *mm;
+ } else {
+		mm_ctx_init(&copy->mm);
+ }
+ if (copy->weight) {
+		if (!dup_trie(&copy->root, &orig->root, dup_cb, mm)) {
+ mm_free(mm, copy);
+ return NULL;
+ }
+ }
+ return copy;
+}
+
+size_t trie_weight(const trie_t *tbl)
+{
+ assert(tbl);
+ return tbl->weight;
+}
+
+trie_val_t* trie_get_try(trie_t *tbl, const trie_key_t *key, uint32_t len)
+{
+ assert(tbl);
+ if (!tbl->weight)
+ return NULL;
+ node_t *t = &tbl->root;
+ while (isbranch(t)) {
+ __builtin_prefetch(twigs(t));
+ bitmap_t b = twigbit(t, key, len);
+ if (!hastwig(t, b))
+ return NULL;
+ t = twig(t, twigoff(t, b));
+ }
+ tkey_t *lkey = tkey(t);
+ if (key_cmp(key, len, lkey->chars, lkey->len) != 0)
+ return NULL;
+ return tvalp(t);
+}
+
+/* Optimization: the approach isn't ideal, as e.g. walking through the prefix
+ * is duplicated and we explicitly construct the wildcard key. Still, it's close
+ * to optimum which would be significantly more complicated and error-prone to write. */
+trie_val_t* trie_get_try_wildcard(trie_t *tbl, const trie_key_t *key, uint32_t len)
+{
+ assert(tbl);
+ if (!tbl->weight)
+ return NULL;
+ // Find leaf sharing the longest common prefix; see ns_find_branch() for explanation.
+ node_t *t = &tbl->root;
+ while (isbranch(t)) {
+ __builtin_prefetch(twigs(t));
+ bitmap_t b = twigbit(t, key, len);
+ uint i = hastwig(t, b) ? twigoff(t, b) : 0;
+ t = twig(t, i);
+ }
+ const tkey_t * const lcp_key = tkey(t);
+
+ // Find the last matching zero byte or -1 (source of synthesis)
+ int i_lmz = -1;
+ for (int i = 0; i < len && i < lcp_key->len && key[i] == lcp_key->chars[i]; ++i) {
+ if (key[i] == '\0' && i < len - 1) // do not count the terminating zero
+ i_lmz = i;
+ // Shortcut: we may have found an exact match.
+ if (i == len - 1 && len == lcp_key->len)
+ return tvalp(t);
+ }
+ if (len == 0) // The empty name needs separate handling.
+ return lcp_key->len == 0 ? tvalp(t) : NULL;
+
+ // Construct the key of the wildcard we need and look it up.
+ const int wild_len = i_lmz + 3;
+ uint8_t wild_key[wild_len];
+ memcpy(wild_key, key, wild_len - 2);
+ wild_key[wild_len - 2] = '*';
+ wild_key[wild_len - 1] = '\0'; // LF is always 0-terminated ATM
+ return trie_get_try(tbl, wild_key, wild_len);
+}
+
+/*! \brief Delete leaf t with parent p; b is the bit for t under p.
+ * Optionally return the deleted value via val. The function can't fail. */
+static void del_found(trie_t *tbl, node_t *t, node_t *p, bitmap_t b, trie_val_t *val)
+{
+ assert(!tkey(t)->cow);
+ mm_free(&tbl->mm, tkey(t));
+ if (val != NULL)
+ *val = *tvalp(t); // we return trie_val_t directly when deleting
+ --tbl->weight;
+ if (unlikely(!p)) { // whole trie was a single leaf
+ assert(tbl->weight == 0);
+ tbl->root = empty_root();
+ return;
+ }
+ // remove leaf t as child of p
+ node_t *tp = twigs(p);
+ uint ci = twig_number(t, p);
+ uint cc = branch_weight(p); // child count
+
+ if (cc == 2) {
+ // collapse binary node p: move the other child to the parent
+ *p = tp[1 - ci];
+ mm_free(&tbl->mm, tp);
+ return;
+ }
+ memmove(tp + ci, tp + ci + 1, sizeof(node_t) * (cc - ci - 1));
+ p->i &= ~b;
+ node_t *newt = mm_realloc(&tbl->mm, tp, sizeof(node_t) * (cc - 1),
+ sizeof(node_t) * cc);
+ if (likely(newt != NULL))
+ p->p = newt;
+ // We can ignore mm_realloc failure because an oversized twig
+ // array is OK - only beware that next time the prev_size
+ // passed to mm_realloc will not be correct; TODO?
+}
+
+int trie_del(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t *val)
+{
+ assert(tbl);
+ if (!tbl->weight)
+ return KNOT_ENOENT;
+ node_t *t = &tbl->root; // current and parent node
+ node_t *p = NULL;
+ bitmap_t b = 0;
+ while (isbranch(t)) {
+ __builtin_prefetch(twigs(t));
+ b = twigbit(t, key, len);
+ if (!hastwig(t, b))
+ return KNOT_ENOENT;
+ p = t;
+ t = twig(t, twigoff(t, b));
+ }
+ tkey_t *lkey = tkey(t);
+ if (key_cmp(key, len, lkey->chars, lkey->len) != 0)
+ return KNOT_ENOENT;
+ del_found(tbl, t, p, b, val);
+ return KNOT_EOK;
+}
+
+/*!
+ * \brief Stack of nodes, storing a path down a trie.
+ *
+ * The structure also serves directly as the public trie_it_t type,
+ * in which case it always points to the current leaf, unless we've finished
+ * (i.e. it->len == 0).
+ * stack[0] is always a valid pointer to the root -> ns_gettrie()
+ */
+typedef struct trie_it {
+ node_t* *stack; /*!< The stack; malloc is used directly instead of mm. */
+ uint32_t len; /*!< Current length of the stack. */
+ uint32_t alen; /*!< Allocated/available length of the stack. */
+ /*! \brief Initial storage for \a stack; it should fit in most use cases. */
+ node_t* stack_init[250];
+} nstack_t;
+
+/*! \brief Create a node stack containing just the root (or empty). */
+static void ns_init(nstack_t *ns, trie_t *tbl)
+{
+ assert(tbl);
+ ns->stack = ns->stack_init;
+ ns->alen = sizeof(ns->stack_init) / sizeof(ns->stack_init[0]);
+ ns->stack[0] = &tbl->root;
+ ns->len = (tbl->weight > 0);
+}
+
+static inline trie_t * ns_gettrie(nstack_t *ns)
+{
+ assert(ns && ns->stack && ns->stack[0]);
+ return (struct trie *)ns->stack[0];
+}
+
+/*! \brief Free inside of the stack, i.e. not the passed pointer itself. */
+static void ns_cleanup(nstack_t *ns)
+{
+ assert(ns && ns->stack);
+ if (likely(ns->stack == ns->stack_init))
+ return;
+ free(ns->stack);
+ #ifndef NDEBUG
+ ns->stack = NULL;
+ ns->alen = 0;
+ #endif
+}
+
+/*! \brief Allocate more space for the stack. */
+static int ns_longer_alloc(nstack_t *ns)
+{
+ ns->alen *= 2;
+ size_t new_size = ns->alen * sizeof(node_t *);
+ node_t **st;
+ if (ns->stack == ns->stack_init) {
+ st = malloc(new_size);
+ if (st != NULL)
+ memcpy(st, ns->stack, ns->len * sizeof(node_t *));
+ } else {
+ st = realloc(ns->stack, new_size);
+ }
+ if (st == NULL)
+ return KNOT_ENOMEM;
+ ns->stack = st;
+ return KNOT_EOK;
+}
+
+/*! \brief Ensure the node stack can be extended by one. */
+static inline int ns_longer(nstack_t *ns)
+{
+ // get a longer stack if needed
+ if (likely(ns->len < ns->alen))
+ return KNOT_EOK;
+ return ns_longer_alloc(ns); // hand-split the part suitable for inlining
+}
+
+/*!
+ * \brief Find the "branching point" as if searching for a key.
+ *
+ * The whole path to the point is kept on the passed stack;
+ * always at least the root will remain on the top of it.
+ * Beware: the precise semantics of this function is rather tricky.
+ * The top of the stack will contain: the corresponding leaf if exact
+ * match is found; or the immediate node below a
+ * branching-point-on-edge or the branching-point itself.
+ *
+ * \param idiff Set the index of first differing nibble, or TMAX_INDEX for an exact match
+ * \param tbit Set the bit of the closest leaf's nibble at index idiff
+ * \param kbit Set the bit of the key's nibble at index idiff
+ *
+ * \return KNOT_EOK or KNOT_ENOMEM.
+ */
+static int ns_find_branch(nstack_t *ns, const trie_key_t *key, uint32_t len,
+ index_t *idiff, bitmap_t *tbit, bitmap_t *kbit)
+{
+ assert(ns && ns->len && idiff);
+ // First find some leaf with longest matching prefix.
+ while (isbranch(ns->stack[ns->len - 1])) {
+ ERR_RETURN(ns_longer(ns));
+ node_t *t = ns->stack[ns->len - 1];
+ __builtin_prefetch(twigs(t));
+ bitmap_t b = twigbit(t, key, len);
+ // Even if our key is missing from this branch we need to
+ // keep iterating down to a leaf. It doesn't matter which
+ // twig we choose since the keys are all the same up to this
+ // index. Note that blindly using twigoff(t, b) can cause
+ // an out-of-bounds index if it equals twigmax(t).
+ uint i = hastwig(t, b) ? twigoff(t, b) : 0;
+ ns->stack[ns->len++] = twig(t, i);
+ }
+ tkey_t *lkey = tkey(ns->stack[ns->len-1]);
+ // Find index of the first char that differs.
+ size_t bytei = 0;
+ uint32_t klen = lkey->len;
+ for (bytei = 0; bytei < MIN(len,klen); bytei++) {
+ if (key[bytei] != lkey->chars[bytei])
+ break;
+ }
+ // Find which half-byte has matched.
+ index_t index = bytei << 1;
+ if (bytei == len && len == lkey->len) { // found equivalent key
+ index = TMAX_INDEX;
+ goto success;
+ }
+ if (likely(bytei < MIN(len,klen))) {
+ uint8_t k2 = (uint8_t)lkey->chars[bytei];
+ uint8_t k1 = (uint8_t)key[bytei];
+ if (((k1 ^ k2) & 0xf0) == 0)
+ index += 1;
+ }
+ // now go up the trie from the current leaf
+ node_t *t;
+ do {
+ if (unlikely(ns->len == 1))
+ goto success; // only the root stays on the stack
+ t = ns->stack[ns->len - 2];
+ if (branch_index(t) < index)
+ goto success;
+ --ns->len;
+ } while (true);
+success:
+ #ifndef NDEBUG // invariants on successful return
+ assert(ns->len);
+ if (isbranch(ns->stack[ns->len - 1])) {
+ t = ns->stack[ns->len - 1];
+ assert(branch_index(t) >= index);
+ }
+ if (ns->len > 1) {
+ t = ns->stack[ns->len - 2];
+ assert(branch_index(t) < index || index == TMAX_INDEX);
+ }
+ #endif
+ *idiff = index;
+ *tbit = keybit(index, lkey->chars, lkey->len);
+ *kbit = keybit(index, key, len);
+ return KNOT_EOK;
+}
+
+/*!
+ * \brief Advance the node stack to the last leaf in the subtree.
+ *
+ * \return KNOT_EOK or KNOT_ENOMEM.
+ */
+static int ns_last_leaf(nstack_t *ns)
+{
+ assert(ns);
+ do {
+ ERR_RETURN(ns_longer(ns));
+ node_t *t = ns->stack[ns->len - 1];
+ if (!isbranch(t))
+ return KNOT_EOK;
+ uint lasti = branch_weight(t) - 1;
+ ns->stack[ns->len++] = twig(t, lasti);
+ } while (true);
+}
+
+/*!
+ * \brief Advance the node stack to the first leaf in the subtree.
+ *
+ * \return KNOT_EOK or KNOT_ENOMEM.
+ */
+static int ns_first_leaf(nstack_t *ns)
+{
+ assert(ns && ns->len);
+ do {
+ ERR_RETURN(ns_longer(ns));
+ node_t *t = ns->stack[ns->len - 1];
+ if (!isbranch(t))
+ return KNOT_EOK;
+ ns->stack[ns->len++] = twig(t, 0);
+ } while (true);
+}
+
+/*!
+ * \brief Advance the node stack to the leaf that is previous to the current node.
+ *
+ * \note Prefix leaf under the current node DOES count (if present; perhaps questionable).
+ * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM.
+ */
+static int ns_prev_leaf(nstack_t *ns)
+{
+ assert(ns && ns->len > 0);
+
+ node_t *t = ns->stack[ns->len - 1];
+ // Beware: BMP_NOBYTE child is ordered *before* its parent.
+ if (isbranch(t) && hastwig(t, BMP_NOBYTE)) {
+ ERR_RETURN(ns_longer(ns));
+ ns->stack[ns->len++] = twig(t, 0);
+ return KNOT_EOK;
+ }
+
+ for (; ns->len >= 2; --ns->len) {
+ t = ns->stack[ns->len - 1];
+ node_t *p = ns->stack[ns->len - 2];
+ uint ci = twig_number(t, p);
+ if (ci == 0) // we've got to go up again
+ continue;
+ // t isn't the first child -> go down the previous one
+ ns->stack[ns->len - 1] = twig(p, ci - 1);
+ return ns_last_leaf(ns);
+ }
+ return KNOT_ENOENT; // root without empty key has no previous leaf
+}
+
+/*!
+ * \brief Advance the node stack to the leaf that is successor to the current node.
+ *
+ * \param skip_prefixed skip any nodes whose key is a prefix of the current one.
+ * If false, prefix leaf or anything else under the current node DOES count.
+ * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM.
+ */
+static int ns_next_leaf(nstack_t *ns, const bool skip_prefixed)
+{
+ assert(ns && ns->len > 0);
+
+ node_t *t = ns->stack[ns->len - 1];
+ if (!skip_prefixed && isbranch(t))
+ return ns_first_leaf(ns);
+ for (; ns->len >= 2; --ns->len) {
+ t = ns->stack[ns->len - 1];
+ node_t *p = ns->stack[ns->len - 2];
+ uint ci = twig_number(t, p);
+ if (skip_prefixed && ci == 0 && hastwig(t, BMP_NOBYTE)) {
+ // Keys in the subtree of p are suffixes of the key of t,
+ // so we've got to go one level higher
+ // (this can't happen more than once)
+ continue;
+ }
+ uint cc = branch_weight(p);
+ assert(ci + 1 <= cc);
+ if (ci + 1 == cc) {
+ // t is the last child of p, so we need to keep climbing
+ continue;
+ }
+ // go down the next child of p
+ ns->stack[ns->len - 1] = twig(p, ci + 1);
+ return ns_first_leaf(ns);
+ }
+ return KNOT_ENOENT; // not found, as no more parent is available
+}
+
+/*! \brief Advance the node stack to leaf with longest prefix of the current key. */
+static int ns_prefix(nstack_t *ns)
+{
+ assert(ns && ns->len > 0);
+ const node_t *start = ns->stack[ns->len - 1];
+ // Walk up the trie until we find a BMP_NOBYTE child.
+ while (--ns->len > 0) {
+ node_t *p = ns->stack[ns->len - 1];
+ if (!hastwig(p, BMP_NOBYTE))
+ continue;
+ node_t *end = twig(p, 0);
+ // In case we started in a BMP_NOBYTE leaf, the first step up
+ // did NOT shorten the key and we would get back into the same
+ // node again.
+ if (end == start)
+ continue;
+ ns->stack[ns->len++] = end;
+ return KNOT_EOK;
+ }
+ return KNOT_ENOENT; // not found, as no more parent is available
+}
+
+/*! \brief less-or-equal search.
+ *
+ * \return KNOT_EOK for exact match, 1 for previous, KNOT_ENOENT for not-found,
+ * or KNOT_E*.
+ */
+static int ns_get_leq(nstack_t *ns, const trie_key_t *key, uint32_t len)
+{
+ // First find the key with longest-matching prefix
+ index_t idiff;
+ bitmap_t tbit, kbit;
+ ERR_RETURN(ns_find_branch(ns, key, len, &idiff, &tbit, &kbit));
+ node_t *t = ns->stack[ns->len - 1];
+ if (idiff == TMAX_INDEX) // found exact match
+ return KNOT_EOK;
+ // Get t: the last node on matching path
+ bitmap_t b;
+ if (isbranch(t) && branch_index(t) == idiff) {
+ // t is OK
+ b = kbit;
+ } else {
+ // the top of the stack was the first unmatched node -> step up
+ if (ns->len == 1) {
+ // root was unmatched already
+ if (kbit < tbit)
+ return KNOT_ENOENT;
+ ERR_RETURN(ns_last_leaf(ns));
+ return 1;
+ }
+ --ns->len;
+ t = ns->stack[ns->len - 1];
+ b = twigbit(t, key, len);
+ }
+ // Now we re-do the first "non-matching" step in the trie
+ // but try the previous child if key was less (it may not exist)
+ int i = hastwig(t, b)
+ ? (int)twigoff(t, b) - (kbit < tbit)
+ : (int)twigoff(t, b) - 1 /* twigoff returns successor when !hastwig */;
+ if (i >= 0) {
+ ERR_RETURN(ns_longer(ns));
+ ns->stack[ns->len++] = twig(t, i);
+ ERR_RETURN(ns_last_leaf(ns));
+ } else {
+ ERR_RETURN(ns_prev_leaf(ns));
+ }
+ return 1;
+}
+
+int trie_get_leq(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t **val)
+{
+ assert(tbl && val);
+ if (tbl->weight == 0) {
+ if (val) *val = NULL;
+ return KNOT_ENOENT;
+ }
+ // We try to do without malloc.
+ nstack_t ns_local;
+ ns_init(&ns_local, tbl);
+ nstack_t *ns = &ns_local;
+
+ int ret = ns_get_leq(ns, key, len);
+ if (ret == KNOT_EOK || ret == 1) {
+ assert(!isbranch(ns->stack[ns->len - 1]));
+ if (val) *val = tvalp(ns->stack[ns->len - 1]);
+ } else {
+ if (val) *val = NULL;
+ }
+ ns_cleanup(ns);
+ return ret;
+}
+
+int trie_it_get_leq(trie_it_t *it, const trie_key_t *key, uint32_t len)
+{
+ assert(it && it->stack[0] && it->alen);
+ const trie_t *tbl = ns_gettrie(it);
+ if (tbl->weight == 0) {
+ it->len = 0;
+ return KNOT_ENOENT;
+ }
+ it->len = 1;
+ int ret = ns_get_leq(it, key, len);
+ if (ret == KNOT_EOK || ret == 1) {
+ assert(trie_it_key(it, NULL));
+ } else {
+ it->len = 0;
+ }
+ return ret;
+}
+
+/* see below */
+static int cow_pushdown(trie_cow_t *cow, nstack_t *ns);
+
+/*! \brief implementation of trie_get_ins() and trie_get_cow() */
+static trie_val_t* cow_get_ins(trie_cow_t *cow, trie_t *tbl,
+ const trie_key_t *key, uint32_t len)
+{
+ assert(tbl);
+ // First leaf in an empty tbl?
+ if (unlikely(!tbl->weight)) {
+ if (unlikely(mkleaf(&tbl->root, key, len, &tbl->mm)))
+ return NULL;
+ ++tbl->weight;
+ return tvalp(&tbl->root);
+ }
+ { // Intentionally un-indented; until end of function, to bound cleanup attr.
+ // Find the branching-point
+ __attribute__((cleanup(ns_cleanup)))
+ nstack_t ns_local;
+ ns_init(&ns_local, tbl);
+ nstack_t *ns = &ns_local;
+ index_t idiff;
+ bitmap_t tbit, kbit;
+ if (unlikely(ns_find_branch(ns, key, len, &idiff, &tbit, &kbit)))
+ return NULL;
+ if (unlikely(cow && cow_pushdown(cow, ns) != KNOT_EOK))
+ return NULL;
+ node_t *t = ns->stack[ns->len - 1];
+ if (idiff == TMAX_INDEX) // the same key was already present
+ return tvalp(t);
+ node_t leaf, *leafp;
+ if (unlikely(mkleaf(&leaf, key, len, &tbl->mm)))
+ return NULL;
+
+ if (isbranch(t) && branch_index(t) == idiff) {
+ // The node t needs a new leaf child.
+ assert(!hastwig(t, kbit));
+ // new child position and original child count
+ uint s = twigoff(t, kbit);
+ uint m = branch_weight(t);
+ node_t *nt = mm_realloc(&tbl->mm, twigs(t),
+ sizeof(node_t) * (m + 1), sizeof(node_t) * m);
+ if (unlikely(!nt))
+ goto err_leaf;
+ memmove(nt + s + 1, nt + s, sizeof(node_t) * (m - s));
+ leafp = nt + s;
+ *t = mkbranch(idiff, branch_bmp(t) | kbit, nt);
+ } else {
+ // We need to insert a new binary branch with leaf at *t.
+ // Note: it works the same for the case where we insert above root t.
+ #ifndef NDEBUG
+ if (ns->len > 1) {
+ node_t *pt = ns->stack[ns->len - 2];
+ assert(hastwig(pt, twigbit(pt, key, len)));
+ }
+ #endif
+ node_t *nt = mm_alloc(&tbl->mm, sizeof(node_t) * 2);
+ if (unlikely(!nt))
+ goto err_leaf;
+ node_t t2 = *t; // Save before overwriting t.
+ *t = mkbranch(idiff, tbit | kbit, nt);
+ *twig(t, twigoff(t, tbit)) = t2;
+ leafp = twig(t, twigoff(t, kbit));
+ };
+ *leafp = leaf;
+ ++tbl->weight;
+ return tvalp(leafp);
+err_leaf:
+ mm_free(&tbl->mm, tkey(&leaf));
+ return NULL;
+ }
+}
+
+trie_val_t* trie_get_ins(trie_t *tbl, const trie_key_t *key, uint32_t len)
+{
+ return cow_get_ins(NULL, tbl, key, len);
+}
+
+/*! \brief Apply a function to every trie_val_t*, in order; a recursive solution. */
+static int apply_nodes(node_t *t, int (*f)(trie_val_t *, void *), void *d)
+{
+ assert(t);
+ if (!isbranch(t))
+ return f(tvalp(t), d);
+ uint n = branch_weight(t);
+ for (uint i = 0; i < n; ++i)
+ ERR_RETURN(apply_nodes(twig(t, i), f, d));
+ return KNOT_EOK;
+}
+
+int trie_apply(trie_t *tbl, int (*f)(trie_val_t *, void *), void *d)
+{
+ assert(tbl && f);
+ if (!tbl->weight)
+ return KNOT_EOK;
+ return apply_nodes(&tbl->root, f, d);
+}
+
+/* These are all thin wrappers around the static ns_* functions. */
+trie_it_t* trie_it_begin(trie_t *tbl)
+{
+ assert(tbl);
+ trie_it_t *it = malloc(sizeof(nstack_t));
+ if (!it)
+ return NULL;
+ ns_init(it, tbl);
+ if (it->len == 0) // empty tbl
+ return it;
+ if (ns_first_leaf(it)) {
+ ns_cleanup(it);
+ free(it);
+ return NULL;
+ }
+ return it;
+}
+
+bool trie_it_finished(trie_it_t *it)
+{
+ assert(it);
+ return it->len == 0;
+}
+
+void trie_it_free(trie_it_t *it)
+{
+ if (!it)
+ return;
+ ns_cleanup(it);
+ free(it);
+}
+
+trie_it_t *trie_it_clone(const trie_it_t *it)
+{
+ if (!it) // TODO: or should that be an assertion?
+ return NULL;
+ trie_it_t *it2 = malloc(sizeof(nstack_t));
+ if (!it2)
+ return NULL;
+ it2->len = it->len;
+ it2->alen = it->alen; // we _might_ change it in the rare malloc case, but...
+ if (likely(it->stack == it->stack_init)) {
+ it2->stack = it2->stack_init;
+ assert(it->alen == sizeof(it->stack_init) / sizeof(it->stack_init[0]));
+ } else {
+ it2->stack = malloc(it2->alen * sizeof(it2->stack[0]));
+ if (!it2->stack) {
+ free(it2);
+ return NULL;
+ }
+ }
+ memcpy(it2->stack, it->stack, it->len * sizeof(it->stack[0]));
+ return it2;
+}
+
+const trie_key_t* trie_it_key(trie_it_t *it, size_t *len)
+{
+ assert(it && it->len);
+ node_t *t = it->stack[it->len - 1];
+ assert(!isbranch(t));
+ tkey_t *key = tkey(t);
+ if (len)
+ *len = key->len;
+ return key->chars;
+}
+
+trie_val_t* trie_it_val(trie_it_t *it)
+{
+ assert(it && it->len);
+ node_t *t = it->stack[it->len - 1];
+ assert(!isbranch(t));
+ return tvalp(t);
+}
+
+void trie_it_next(trie_it_t *it)
+{
+ assert(it && it->len);
+ if (ns_next_leaf(it, false) != KNOT_EOK)
+ it->len = 0;
+}
+
+void trie_it_next_loop(trie_it_t *it)
+{
+ assert(it && it->len);
+ int ret = ns_next_leaf(it, false);
+ if (ret == KNOT_ENOENT) {
+ it->len = 1;
+ ret = ns_first_leaf(it);
+ }
+ if (ret)
+ it->len = 0;
+}
+
+void trie_it_next_nosub(trie_it_t *it)
+{
+ assert(it && it->len);
+ if (ns_next_leaf(it, true) != KNOT_EOK)
+ it->len = 0;
+}
+
+void trie_it_prev(trie_it_t *it)
+{
+ assert(it && it->len);
+ if (ns_prev_leaf(it) != KNOT_EOK)
+ it->len = 0;
+}
+
+void trie_it_prev_loop(trie_it_t *it)
+{
+ assert(it && it->len);
+ int ret = ns_prev_leaf(it);
+ if (ret == KNOT_ENOENT) {
+ it->len = 1;
+ ret = ns_last_leaf(it);
+ }
+ if (ret)
+ it->len = 0;
+}
+
+void trie_it_parent(trie_it_t *it)
+{
+ assert(it && it->len);
+ if (ns_prefix(it))
+ it->len = 0;
+}
+
+void trie_it_del(trie_it_t *it)
+{
+ assert(it && it->len);
+ if (it->len == 0)
+ return;
+ node_t *t = it->stack[it->len - 1];
+ assert(!isbranch(t));
+ bitmap_t b; // del_found() needs to know which bit to zero in the bitmap
+ node_t *p;
+ if (it->len == 1) { // deleting the root
+ p = NULL;
+ b = 0; // unused
+ } else {
+ p = it->stack[it->len - 2];
+ assert(isbranch(p));
+ size_t len;
+ const trie_key_t *key = trie_it_key(it, &len);
+ b = twigbit(p, key, len);
+ }
+ // We could trie_it_{next,prev,...}(it) now, in case we wanted that semantics.
+ it->len = 0;
+ del_found(ns_gettrie(it), t, p, b, NULL);
+}
+
+
+/*!\file
+ *
+ * \section About copy-on-write
+ *
+ * In these notes I'll use the term "object" to refer to either the
+ * twig array of a branch, or the application's data that is referred
+ * to by a leaf's trie_val_t pointer. Note that for COW we don't care
+ * about trie node_t structs themselves, but the objects that they
+ * point to.
+ *
+ * \subsection COW states
+ *
+ * During a COW transaction an object can be in one of three states:
+ * shared, only in the old trie, or only in the new trie. When a
+ * transaction is rolled back, the only-new objects are freed; when a
+ * transaction is committed the new trie takes the place of the old
+ * one and only-old objects are freed.
+ *
+ * \subsection branch marks and regions
+ *
+ * A branch object can be marked by setting the COW flag in the first
+ * element of its twig array. Marked branches partition the trie into
+ * regions; an object's state depends on its region.
+ *
+ * The unmarked branch objects between a trie's root and the marked
+ * branches (excluding the marked branches themselves) are exclusively
+ * owned: either old-only (if you started from the old root) or
+ * new-only (if you started from the new root).
+ *
+ * Marked branch objects, and all objects reachable from marked branch
+ * objects, are in the shared region accessible from both old and new
+ * roots. All branch objects below a marked branch must be unmarked.
+ * (That is, there is at most one marked branch object on any path
+ * from the root of a trie.)
+ *
+ * Branch nodes in the new-only region can be modified in place, in
+ * the same way as an original qp trie. Branch nodes in the old-only
+ * or shared regions must not be modified.
+ *
+ * \subsection app object states
+ *
+ * The app objects reachable from the new-only and old-only regions
+ * explicitly record their state in a way determined by the
+ * application. (These app objects are reachable from the old and new
+ * roots by traversing only unmarked branch objects.)
+ *
+ * The app objects reachable from marked branch objects are implicitly
+ * shared, but their state field has an indeterminate value. If an app
+ * object was previously touched by a rolled-back transaction it may
+ * be marked shared or old-only; if it was previously touched by a
+ * committed transaction it may be marked shared or new-only.
+ *
+ * \subsection key states
+ *
+ * The memory allocated for tkey_t objects also needs to track its
+ * sharing state. They have a "cow" flag to mark when they are shared.
+ * Keys are relatively lazily copied (to make them exclusive) when
+ * their leaf node is touched by a COW mutation.
+ *
+ * [An alternative technique might be to copy them more eagerly, in
+ * cow_pushdown(), which would avoid the need for a flag bit at the
+ * cost of more allocator churn in a transaction.]
+ *
+ * \subsection outside COW
+ *
+ * When a COW transaction is not in progress, there are no marked
+ * branch objects, so everything is exclusively owned. When a COW
+ * transaction is finished (committed or rolled back), the branch
+ * marks are removed. Since they are in the shared region, this branch
+ * cleanup is visible to both old and new tries.
+ *
+ * However the state of app objects is not clean between COW
+ * transactions. When a COW transaction is committed, we traverse the
+ * old-only region to find old-only app objects that should be freed
+ * (and vice versa for rollback). In general, there will be app
+ * objects that are only reachable from the new-only region, and that
+ * have a mixture of shared and new states.
+ */
+
+/*! \brief Trie copy-on-write state */
+struct trie_cow {
+ trie_t *old;
+ trie_t *new;
+ trie_cb *mark_shared;
+ void *d;
+};
+
+/*! \brief is this a marked branch object */
+static bool cow_marked(node_t *t)
+{
+ return isbranch(t) && (twigs(t)->i & TFLAG_COW);
+}
+
+/*! \brief is this a leaf with a marked key */
+static bool cow_key(node_t *t)
+{
+ return !isbranch(t) && tkey(t)->cow;
+}
+
+/*! \brief remove mark from a branch object */
+static void clear_cow(node_t *t)
+{
+ assert(isbranch(t));
+ twigs(t)->i &= ~TFLAG_COW;
+}
+
+/*! \brief mark a node as shared
+ *
+ * For branches this marks the twig array (in COW terminology, the
+ * branch object); for leaves it uses the callback to mark the app
+ * object.
+ */
+static void mark_cow(trie_cow_t *cow, node_t *t)
+{
+ if (isbranch(t)) {
+ node_t *object = twigs(t);
+ object->i |= TFLAG_COW;
+ } else {
+ tkey_t *lkey = tkey(t);
+ lkey->cow = 1;
+ if (cow->mark_shared != NULL) {
+ trie_val_t *valp = tvalp(t);
+ cow->mark_shared(*valp, lkey->chars, lkey->len, cow->d);
+ }
+ }
+}
+
+/*! \brief push exclusive COW region down one node */
+static int cow_pushdown_one(trie_cow_t *cow, node_t *t)
+{
+ uint cc = branch_weight(t);
+ node_t *nt = mm_alloc(&cow->new->mm, sizeof(node_t) * cc);
+ if (nt == NULL)
+ return KNOT_ENOMEM;
+ /* mark all the children */
+ for (uint ci = 0; ci < cc; ++ci)
+ mark_cow(cow, twig(t, ci));
+ /* this node must be unmarked in both old and new versions */
+ clear_cow(t);
+ t->p = memcpy(nt, twigs(t), sizeof(node_t) * cc);
+ return KNOT_EOK;
+}
+
+/*! \brief push exclusive COW region to cover a whole node stack */
+static int cow_pushdown(trie_cow_t *cow, nstack_t *ns)
+{
+ node_t *new_twigs = NULL;
+ node_t *old_twigs = NULL;
+ for (uint i = 0; i < ns->len; i++) {
+ /* if we did a pushdown on the previous iteration, we
+ need to update this stack entry so it points into
+ the parent's new twigs instead of the old ones */
+ if (new_twigs != old_twigs)
+ ns->stack[i] = new_twigs + (ns->stack[i] - old_twigs);
+ if (cow_marked(ns->stack[i])) {
+ old_twigs = twigs(ns->stack[i]);
+ if (cow_pushdown_one(cow, ns->stack[i]))
+ return KNOT_ENOMEM;
+ new_twigs = twigs(ns->stack[i]);
+ } else {
+ new_twigs = NULL;
+ old_twigs = NULL;
+ /* ensure key is exclusively owned */
+ if (cow_key(ns->stack[i])) {
+ node_t oleaf = *ns->stack[i];
+ tkey_t *okey = tkey(&oleaf);
+ if(mkleaf(ns->stack[i], okey->chars, okey->len,
+ &cow->new->mm))
+ return KNOT_ENOMEM;
+ ns->stack[i]->p = oleaf.p;
+ okey->cow = 0;
+ }
+ }
+ }
+ return KNOT_EOK;
+}
+
+trie_cow_t* trie_cow(trie_t *old, trie_cb *mark_shared, void *d)
+{
+ knot_mm_t *mm = &old->mm;
+ trie_t *new = mm_alloc(mm, sizeof(trie_t));
+ trie_cow_t *cow = mm_alloc(mm, sizeof(trie_cow_t));
+ if (new == NULL || cow == NULL) {
+ mm_free(mm, new);
+ mm_free(mm, cow);
+ return NULL;
+ }
+ new->mm = old->mm;
+ new->root = old->root;
+ new->weight = old->weight;
+ cow->old = old;
+ cow->new = new;
+ cow->mark_shared = mark_shared;
+ cow->d = d;
+ if (old->weight)
+ mark_cow(cow, &old->root);
+ return cow;
+}
+
+trie_t* trie_cow_new(trie_cow_t *cow)
+{
+ assert(cow != NULL);
+ return cow->new;
+}
+
+trie_val_t* trie_get_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len)
+{
+ return cow_get_ins(cow, cow->new, key, len);
+}
+
+int trie_del_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len, trie_val_t *val)
+{
+ trie_t *tbl = cow->new;
+ if (unlikely(!tbl->weight))
+ return KNOT_ENOENT;
+ { // Intentionally un-indented; until end of function, to bound cleanup attr.
+ // Find the branching-point
+ __attribute__((cleanup(ns_cleanup)))
+ nstack_t ns_local;
+ ns_init(&ns_local, tbl);
+ nstack_t *ns = &ns_local;
+ index_t idiff;
+ bitmap_t tbit, kbit;
+ ERR_RETURN(ns_find_branch(ns, key, len, &idiff, &tbit, &kbit));
+ if (idiff != TMAX_INDEX)
+ return KNOT_ENOENT;
+ ERR_RETURN(cow_pushdown(cow, ns));
+ node_t *t = ns->stack[ns->len - 1];
+ node_t *p = ns->len >= 2 ? ns->stack[ns->len - 2] : NULL;
+ del_found(tbl, t, p, p ? twigbit(p, key, len) : 0, val);
+ }
+ return KNOT_EOK;
+}
+
+/*! \brief clean up after a COW transaction, recursively */
+static void cow_cleanup(trie_cow_t *cow, node_t *t, trie_cb *cb, void *d)
+{
+ if (cow_marked(t)) {
+ // we have hit the shared region, so just reset the mark
+ clear_cow(t);
+ return;
+ } else if (isbranch(t)) {
+ // traverse and free the exclusive region
+ uint cc = branch_weight(t);
+ for (uint ci = 0; ci < cc; ++ci)
+ cow_cleanup(cow, twig(t, ci), cb, d);
+ mm_free(&cow->new->mm, twigs(t));
+ return;
+ } else {
+ // application must decide how to clean up its values
+ tkey_t *lkey = tkey(t);
+ if (cb != NULL) {
+ trie_val_t *valp = tvalp(t);
+ cb(*valp, lkey->chars, lkey->len, d);
+ }
+ // clean up exclusively-owned keys
+ if (lkey->cow)
+ lkey->cow = 0;
+ else
+ mm_free(&cow->new->mm, lkey);
+ return;
+ }
+}
+
+trie_t* trie_cow_commit(trie_cow_t *cow, trie_cb *cb, void *d)
+{
+ trie_t *ret = cow->new;
+ if (cow->old->weight)
+ cow_cleanup(cow, &cow->old->root, cb, d);
+ mm_free(&ret->mm, cow->old);
+ mm_free(&ret->mm, cow);
+ return ret;
+}
+
+trie_t* trie_cow_rollback(trie_cow_t *cow, trie_cb *cb, void *d)
+{
+ trie_t *ret = cow->old;
+ if (cow->new->weight)
+ cow_cleanup(cow, &cow->new->root, cb, d);
+ mm_free(&ret->mm, cow->new);
+ mm_free(&ret->mm, cow);
+ return ret;
+}
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ Copyright (C) 2018 Tony Finch <dot@dotat.at>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "libknot/mm_ctx.h"
+
+/*!
+ * \brief Native API of QP-tries:
+ *
+ * - keys are uint8_t strings, not necessarily zero-terminated,
+ * the structure copies the contents of the passed keys
+ * - values are void* pointers, typically you get an ephemeral pointer to it
+ * - key lengths are limited by 2^32-1 ATM
+ */
+
+/*! \brief Element value. */
+typedef void* trie_val_t;
+/*! \brief Key for indexing tries. Sign could be flipped easily. */
+typedef uint8_t trie_key_t;
+
+/*! \brief Opaque structure holding a QP-trie. */
+typedef struct trie trie_t;
+
+/*! \brief Opaque type for holding a QP-trie iterator. */
+typedef struct trie_it trie_it_t;
+
+/*! \brief Callback for cloning trie values. */
+typedef trie_val_t (*trie_dup_cb)(const trie_val_t val, knot_mm_t *mm);
+
+/*! \brief Callback for performing actions on a trie leaf
+ *
+ * Used during copy-on-write transactions
+ *
+ * \param val The value of the element to be altered
+ * \param key The key of the element to be altered
+ * \param len The length of key
+ * \param d Additional user data
+ */
+typedef void trie_cb(trie_val_t val, const trie_key_t *key, size_t len, void *d);
+
+/*! \brief Opaque type for holding the copy-on-write state for a QP-trie. */
+typedef struct trie_cow trie_cow_t;
+
+/*! \brief Create a trie instance. */
+trie_t* trie_create(knot_mm_t *mm);
+
+/*! \brief Free a trie instance. */
+void trie_free(trie_t *tbl);
+
+/*! \brief Clear a trie instance (make it empty). */
+void trie_clear(trie_t *tbl);
+
+/*! \brief Create a clone of existing trie. */
+trie_t* trie_dup(const trie_t *orig, trie_dup_cb dup_cb, knot_mm_t *mm);
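+/*
+ * A minimal sketch of a dup callback (hypothetical value type `struct my_val`
+ * and memory context `my_mm`; not part of this API). The callback allocates the
+ * copied value from the destination trie's memory context via mm_alloc(), which
+ * is used throughout this code:
+ *
+ *   static trie_val_t my_dup_cb(const trie_val_t val, knot_mm_t *mm)
+ *   {
+ *       struct my_val *copy = mm_alloc(mm, sizeof(*copy));
+ *       if (copy != NULL)
+ *           *copy = *(const struct my_val *)val;
+ *       return copy;
+ *   }
+ *
+ *   trie_t *clone = trie_dup(orig, my_dup_cb, &my_mm);
+ */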
+
+/*! \brief Return the number of keys in the trie. */
+size_t trie_weight(const trie_t *tbl);
+
+/*! \brief Search the trie, returning NULL on failure. */
+trie_val_t* trie_get_try(trie_t *tbl, const trie_key_t *key, uint32_t len);
+
+/*! \brief Search the trie including DNS wildcard semantics, returning NULL on failure.
+ *
+ * \note We assume the key is in knot_dname_lf() format, i.e. labels are ordered
+ * from root to leaf and separated by zero bytes (and no other zeros are allowed).
+ * \note Beware that DNS wildcard matching is not exactly what normal people would expect.
+ */
+trie_val_t* trie_get_try_wildcard(trie_t *tbl, const trie_key_t *key, uint32_t len);
+
+/*! \brief Search the trie, inserting NULL trie_val_t on failure. */
+trie_val_t* trie_get_ins(trie_t *tbl, const trie_key_t *key, uint32_t len);
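+/*
+ * A minimal insert-and-lookup sketch (hypothetical key and value; error handling
+ * omitted). It assumes trie_create(NULL) falls back to plain malloc, as trie_dup()
+ * does with a NULL mm; newly inserted slots start out as NULL:
+ *
+ *   trie_t *t = trie_create(NULL);
+ *   trie_val_t *slot = trie_get_ins(t, (const trie_key_t *)"key", 3);
+ *   if (slot != NULL)
+ *       *slot = my_value;                           // any void* payload
+ *   trie_val_t *found = trie_get_try(t, (const trie_key_t *)"key", 3);
+ *   trie_free(t);                                   // the stored pointers remain the caller's to free
+ */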
+
+/*!
+ * \brief Search for less-or-equal element.
+ *
+ * \param tbl Trie.
+ * \param key Searched key.
+ * \param len Key length.
+ * \param val (optional) Value found; it will be set to NULL if not found or errored.
+ * \return KNOT_EOK for exact match, 1 for previous, KNOT_ENOENT for not-found,
+ * or KNOT_E*.
+ */
+int trie_get_leq(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t **val);
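+/*
+ * Return-value sketch for a hypothetical lookup of the key "foo":
+ *
+ *   trie_val_t *v;
+ *   int ret = trie_get_leq(tbl, (const trie_key_t *)"foo", 3, &v);
+ *   // ret == KNOT_EOK    -> "foo" itself is stored; v points to its value
+ *   // ret == 1           -> v points to the value of the greatest key < "foo"
+ *   // ret == KNOT_ENOENT -> no key <= "foo" exists; v was set to NULL
+ */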
+
+/*!
+ * \brief Apply a function to every trie_val_t, in order.
+ *
+ * \return KNOT_EOK if success or KNOT_E* if error.
+ */
+int trie_apply(trie_t *tbl, int (*f)(trie_val_t *, void *), void *d);
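+/*
+ * A minimal callback sketch (hypothetical): count the stored elements.
+ * A non-KNOT_EOK return value propagates up and stops the walk early:
+ *
+ *   static int count_cb(trie_val_t *val, void *d)
+ *   {
+ *       (void)val;
+ *       ++*(size_t *)d;
+ *       return KNOT_EOK;
+ *   }
+ *
+ *   size_t n = 0;
+ *   trie_apply(tbl, count_cb, &n);   // afterwards n == trie_weight(tbl)
+ */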
+
+/*!
+ * \brief Remove an item, returning KNOT_EOK if succeeded or KNOT_ENOENT if not found.
+ *
+ * If val!=NULL and deletion succeeded, the deleted value is set.
+ */
+int trie_del(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t *val);
+
+
+/*! \brief Create a new iterator pointing to the first element (if any).
+ *
+ * trie_it_* functions deal with these iterators capable of walking and jumping
+ * over the trie. Note that any modification to key-set stored by the trie
+ * will in general invalidate all iterators and you will need to begin anew.
+ * (It won't be detected - you may end up reading freed memory, etc.)
+ */
+trie_it_t* trie_it_begin(trie_t *tbl);
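+/*
+ * A minimal iteration sketch; trie_it_begin() may return NULL on allocation
+ * failure, so real code should check it:
+ *
+ *   trie_it_t *it = trie_it_begin(tbl);
+ *   while (it && !trie_it_finished(it)) {
+ *       size_t len;
+ *       const trie_key_t *k = trie_it_key(it, &len);
+ *       trie_val_t *v = trie_it_val(it);
+ *       // ... use k[0 .. len-1] and *v ...
+ *       trie_it_next(it);
+ *   }
+ *   trie_it_free(it);   // OK to call on NULL
+ */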
+
+/*! \brief Test if the iterator has gone "past the end" (and points nowhere). */
+bool trie_it_finished(trie_it_t *it);
+
+/*! \brief Free any resources of the iterator. It's OK to call it on NULL. */
+void trie_it_free(trie_it_t *it);
+
+/*! \brief Copy the iterator. See the warning in trie_it_begin(). */
+trie_it_t *trie_it_clone(const trie_it_t *it);
+
+/*!
+ * \brief Return pointer to the key of the current element.
+ *
+ * \note The length is uint32_t internally, but size_t is returned here so that
+ * callers usually don't need an additional type conversion.
+ */
+const trie_key_t* trie_it_key(trie_it_t *it, size_t *len);
+
+/*! \brief Return pointer to the value of the current element (writable). */
+trie_val_t* trie_it_val(trie_it_t *it);
+
+/*!
+ * \brief Advance the iterator to the next element.
+ *
+ * Iteration is in ascending lexicographical order.
+ * In particular, the empty string would be considered as the very first.
+ *
+ * \TODO: in most iterator operations, ENOMEM is very unlikely
+ * but it leads to a _finished() iterator (silently).
+ * Perhaps the functions should simply return KNOT_E*
+ */
+void trie_it_next(trie_it_t *it);
+/*! \brief Advance the iterator to the previous element. See trie_it_next(). */
+void trie_it_prev(trie_it_t *it);
+
+/*! \brief Advance iterator to the next element, looping to first after last. */
+void trie_it_next_loop(trie_it_t *it);
+/*! \brief Advance iterator to the previous element, looping to last after first. */
+void trie_it_prev_loop(trie_it_t *it);
+
+/*! \brief Advance iterator to the next element while ignoring the subtree.
+ *
+ * \note Another formulation: skip keys that are prefixed by the current key.
+ * \TODO: name, maybe _noprefixed? The thing is that in the "subtree" meaning
+ * doesn't correspond to how the pointers go in the implementation,
+ * but we may not care much for implementation in the API...
+ */
+void trie_it_next_nosub(trie_it_t *it);
+
+/*! \brief Advance iterator to the longest prefix of the current key.
+ *
+ * \TODO: name, maybe _prefix? Arguments similar to _nosub vs. _noprefixed.
+ */
+void trie_it_parent(trie_it_t *it);
+
+/*! \brief trie_get_leq() but with an iterator. */
+int trie_it_get_leq(trie_it_t *it, const trie_key_t *key, uint32_t len);
+
+/*! \brief Remove the current element. The iterator will get trie_it_finished() */
+void trie_it_del(trie_it_t *it);
+
+
+/*! \brief Start a COW transaction
+ *
+ * A copy-on-write transaction starts by obtaining a write lock (in
+ * your application code) followed by a call to trie_cow(). This
+ * creates a shared clone of the trie and saves both old and new roots
+ * in the COW context.
+ *
+ * During the COW transaction, you call trie_get_cow() or
+ * trie_del_cow() as necessary. These calls ensure that the relevant
+ * parts of the (new) trie are copied so that they can be modified
+ * freely.
+ *
+ * Your trie_val_t objects must be able to distinguish their
+ * reachability, either shared, or old-only, or new-only. Before a COW
+ * transaction the reachability of your objects is indeterminate.
+ * During a transaction, any trie_val_t objects that might be affected
+ * (because they are adjacent to a trie_get_cow() or trie_del_cow())
+ * are first marked as shared using the callback you pass to
+ * trie_cow().
+ *
+ * When the transaction is complete, to commit, call trie_cow_new() to
+ * get the new root, swap the old and new trie roots (e.g. with
+ * rcu_xchg_pointer()), wait for readers to finish with the old trie
+ * (e.g. using synchronize_rcu()), then call trie_cow_commit(). For a
+ * rollback, you can just call trie_cow_rollback() without waiting
+ * since that doesn't conflict with readers. After trie_cow_commit()
+ * or trie_cow_rollback() have finished, you can release your write
+ * lock.
+ *
+ * Concurrent reading of the old trie is allowed during a transaction
+ * provided that it is known when all readers have finished with the
+ * old version, e.g. using rcu_read_lock() and rcu_read_unlock().
+ * There must be only one write transaction at a time.
+ *
+ * \param old the old trie
+ * \param mark_shared callback to mark a leaf as shared (can be NULL)
+ * \param d extra data for the callback
+ * \return a pointer to a COW context,
+ * or NULL if there was a failure
+ */
+trie_cow_t* trie_cow(trie_t *old, trie_cb *mark_shared, void *d);
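+/*
+ * A minimal transaction sketch. The callback names my_mark_shared_cb and
+ * my_free_old_cb, the published_trie pointer, and the surrounding locking are
+ * application-side assumptions; the RCU calls are the ones suggested above:
+ *
+ *   trie_cow_t *cow = trie_cow(old_trie, my_mark_shared_cb, NULL);
+ *   if (cow == NULL)
+ *       goto fail;
+ *   trie_val_t *v = trie_get_cow(cow, key, key_len);  // mutate the new version
+ *   // ... more trie_get_cow()/trie_del_cow() calls ...
+ *   trie_t *new_trie = trie_cow_new(cow);
+ *   rcu_xchg_pointer(&published_trie, new_trie);      // publish the new root
+ *   synchronize_rcu();                                // wait for old readers
+ *   trie_cow_commit(cow, my_free_old_cb, NULL);       // free old-only objects
+ */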
+
+/*! \brief get the new trie from a COW context */
+trie_t* trie_cow_new(trie_cow_t *cow);
+
+/*! \brief variant of trie_get_ins() for use during COW transactions
+ *
+ * As necessary, this copies path from the root of the trie to the
+ * leaf, so that it is no longer shared. Any leaves adjacent to this
+ * path are marked as shared using the mark_shared callback passed to
+ * trie_cow().
+ *
+ * It is your responsibility to COW your trie_val_t objects. If you copy an
+ * object you must change the original's reachability from shared to old-only.
+ * New objects (including copies) must have new-only reachability.
+ */
+trie_val_t* trie_get_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len);
+
+/*!
+ * \brief variant of trie_del() for use during COW transactions
+ *
+ * The mark_shared callback is invoked as necessary, in the same way
+ * as trie_get_cow().
+ *
+ * Returns KNOT_EOK if the key was removed or KNOT_ENOENT if not found.
+ * If val!=NULL and deletion succeeded, the *val is set to the deleted
+ * value pointer.
+ */
+int trie_del_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len, trie_val_t *val);
+
+/*! \brief clean up the old trie after committing a COW transaction
+ *
+ * Your callback is invoked for any trie_val_t objects that might need
+ * cleaning up; you must free any objects you have marked as old-only
+ * and retain objects with shared reachability.
+ *
+ * \note The callback can be NULL.
+ *
+ * The cow object is free()d, and the new trie root is returned.
+ */
+trie_t* trie_cow_commit(trie_cow_t *cow, trie_cb *cb, void *d);
+
+/*! \brief clean up the new trie after rolling back a COW transaction
+ *
+ * Your callback is invoked for any trie_val_t objects that might need
+ * cleaning up; you must free any objects you have marked as new-only
+ * and retain objects with shared reachability.
+ *
+ * \note The callback can be NULL.
+ *
+ * The cow object is free()d, and the old trie root is returned.
+ */
+trie_t* trie_cow_rollback(trie_cow_t *cow, trie_cb *cb, void *d);
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <assert.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#if defined(HAVE_EXPLICIT_BZERO)
+ #if defined(HAVE_BSD_STRING_H)
+ #include <bsd/string.h>
+ #endif
+ /* #include <string.h> is needed. */
+#elif defined(HAVE_EXPLICIT_MEMSET)
+ /* #include <string.h> is needed. */
+#elif defined(HAVE_GNUTLS_MEMSET)
+ #include <gnutls/gnutls.h>
+#else
+ #define USE_CUSTOM_MEMSET
+#endif
+
+#include "string.h"
+#include "ctype.h"
+
+uint8_t *memdup(const uint8_t *data, size_t data_size)
+{
+ uint8_t *result = (uint8_t *)malloc(data_size);
+ if (!result) {
+ return NULL;
+ }
+
+ return memcpy(result, data, data_size);
+}
+
+char *sprintf_alloc(const char *fmt, ...)
+{
+ char *strp = NULL;
+ va_list ap;
+
+ va_start(ap, fmt);
+ int ret = vasprintf(&strp, fmt, ap);
+ va_end(ap);
+
+ if (ret < 0) {
+ return NULL;
+ }
+ return strp;
+}
+
+char *strcdup(const char *s1, const char *s2)
+{
+ if (!s1 || !s2) {
+ return NULL;
+ }
+
+ size_t s1len = strlen(s1);
+ size_t s2len = strlen(s2);
+ size_t nlen = s1len + s2len + 1;
+
+ char* dst = malloc(nlen);
+ if (dst == NULL) {
+ return NULL;
+ }
+
+ memcpy(dst, s1, s1len);
+ memcpy(dst + s1len, s2, s2len + 1);
+ return dst;
+}
+
+char *strstrip(const char *str)
+{
+ // leading white-spaces
+ const char *scan = str;
+ while (is_space(scan[0])) {
+ scan += 1;
+ }
+
+ // trailing white-spaces
+ size_t len = strlen(scan);
+ while (len > 0 && is_space(scan[len - 1])) {
+ len -= 1;
+ }
+
+ char *trimmed = malloc(len + 1);
+ if (!trimmed) {
+ return NULL;
+ }
+
+ memcpy(trimmed, scan, len);
+ trimmed[len] = '\0';
+
+ return trimmed;
+}
+
+int const_time_memcmp(const void *s1, const void *s2, size_t n)
+{
+ volatile uint8_t equal = 0;
+
+ for (size_t i = 0; i < n; i++) {
+ equal |= ((uint8_t *)s1)[i] ^ ((uint8_t *)s2)[i];
+ }
+
+ return equal;
+}
+
+#if defined(USE_CUSTOM_MEMSET)
+typedef void *(*memset_t)(void *, int, size_t);
+static volatile memset_t volatile_memset = memset;
+#endif
+
+void *memzero(void *s, size_t n)
+{
+#if defined(HAVE_EXPLICIT_BZERO) /* In OpenBSD since 5.5. */
+ /* In FreeBSD since 11.0. */
+ /* In glibc since 2.25. */
+ /* In DragonFly BSD since 5.5. */
+ explicit_bzero(s, n);
+ return s;
+#elif defined(HAVE_EXPLICIT_MEMSET) /* In NetBSD since 7.0. */
+ return explicit_memset(s, 0, n);
+#elif defined(HAVE_GNUTLS_MEMSET) /* In GnuTLS since 3.4.0. */
+ gnutls_memset(s, 0, n);
+ return s;
+#else /* Knot custom solution as a fallback. */
+ /* Warning: the use of the return value is *probably* needed
+ * so as to avoid the volatile_memset() to be optimized out.
+ */
+ return volatile_memset(s, 0, n);
+#endif
+}
+
+static const char BIN_TO_HEX[] = {
+ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
+};
+
+char *bin_to_hex(const uint8_t *bin, size_t bin_len)
+{
+ if (bin == NULL) {
+ return NULL;
+ }
+
+ size_t hex_size = bin_len * 2;
+ char *hex = malloc(hex_size + 1);
+ if (hex == NULL) {
+ return NULL;
+ }
+
+ for (size_t i = 0; i < bin_len; i++) {
+ hex[2 * i] = BIN_TO_HEX[bin[i] >> 4];
+ hex[2 * i + 1] = BIN_TO_HEX[bin[i] & 0x0f];
+ }
+ hex[hex_size] = '\0';
+
+ return hex;
+}
+
+/*!
+ * Convert HEX character to numeric value (assumes valid input).
+ */
+static uint8_t hex_to_number(const char hex)
+{
+ if (hex >= '0' && hex <= '9') {
+ return hex - '0';
+ } else if (hex >= 'a' && hex <= 'f') {
+ return hex - 'a' + 10;
+ } else {
+ assert(hex >= 'A' && hex <= 'F');
+ return hex - 'A' + 10;
+ }
+}
+
+uint8_t *hex_to_bin(const char *hex, size_t *out_len)
+{
+ if (hex == NULL || out_len == NULL) {
+ return NULL;
+ }
+
+ size_t hex_len = strlen(hex);
+ if (hex_len % 2 != 0) {
+ return NULL;
+ }
+
+ size_t bin_len = hex_len / 2;
+ uint8_t *bin = malloc(bin_len + 1);
+ if (bin == NULL) {
+ return NULL;
+ }
+
+ for (size_t i = 0; i < bin_len; i++) {
+ if (!is_xdigit(hex[2 * i]) || !is_xdigit(hex[2 * i + 1])) {
+ free(bin);
+ return NULL;
+ }
+ uint8_t high = hex_to_number(hex[2 * i]);
+ uint8_t low = hex_to_number(hex[2 * i + 1]);
+ bin[i] = high << 4 | low;
+ }
+
+ *out_len = bin_len;
+
+ return bin;
+}
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/*!
+ * \brief String manipulations.
+ */
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+/*!
+ * \brief Create a copy of a binary buffer.
+ *
+ * Like \c strdup, but for binary data.
+ */
+uint8_t *memdup(const uint8_t *data, size_t data_size);
+
+/*!
+ * \brief Format string and take care of allocating memory.
+ *
+ * \note A thin wrapper around vasprintf(); see the asprintf(3) manual page.
+ *
+ * \param fmt Message format.
+ * \return formatted message or NULL.
+ */
+char *sprintf_alloc(const char *fmt, ...);
+
+/*!
+ * \brief Create new string from a concatenation of s1 and s2.
+ *
+ * \param s1 First string.
+ * \param s2 Second string.
+ *
+ * \retval Newly allocated string on success.
+ * \retval NULL on error.
+ */
+char *strcdup(const char *s1, const char *s2);
+
+/*!
+ * \brief Create a copy of a string skipping leading and trailing white spaces.
+ *
+ * \return Newly allocated string, NULL in case of error.
+ */
+char *strstrip(const char *str);
+
+/*!
+ * \brief Compare two buffers in constant time (dependent only on the length,
+ * not on the compared contents).
+ * This function only checks for (in)equality, not for ordering.
+ *
+ * \param s1 The first address to compare.
+ * \param s2 The second address to compare.
+ * \param n The size of memory to compare.
+ *
+ * \return Non zero on difference and zero if the buffers are identical.
+ */
+int const_time_memcmp(const void *s1, const void *s2, size_t n);
+
+/*!
+ * \brief Fill memory with zeroes.
+ *
+ * Inspired by OPENSSL_cleanse. Such a memset shouldn't be optimized out.
+ *
+ * \param s The address to fill.
+ * \param n The size of memory to fill.
+ *
+ * \return Pointer to the memory.
+ */
+void *memzero(void *s, size_t n);
+
+/*!
+ * \brief Convert binary data to hexadecimal string.
+ */
+char *bin_to_hex(const uint8_t *bin, size_t bin_len);
+
+/*!
+ * \brief Convert hex encoded string to binary data.
+ */
+uint8_t *hex_to_bin(const char *hex, size_t *out_len);
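+/*
+ * A minimal round-trip sketch (hypothetical buffer; error handling omitted):
+ *
+ *   const uint8_t raw[] = { 0xde, 0xad };
+ *   char *hex = bin_to_hex(raw, sizeof(raw));   // "dead"
+ *   size_t len;
+ *   uint8_t *back = hex_to_bin(hex, &len);      // len == 2, bytes 0xde 0xad
+ *   free(hex);
+ *   free(back);
+ */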
--- /dev/null
+../licenses/LGPL-2.0
\ No newline at end of file
--- /dev/null
+/*
+ * UCW Library -- Universal Simple Array Sorter
+ *
+ * (c) 2003--2008 Martin Mares <mj@ucw.cz>
+ *
+ * This software may be freely distributed and used according to the terms
+ * of the GNU Lesser General Public License.
+ */
+
+#pragma once
+
+#include "contrib/macros.h"
+
+/*
+ * This is not a normal header file, it's a generator of sorting
+ * routines. Each time you include it with parameters set in the
+ * corresponding preprocessor macros, it generates an array sorter
+ * with the parameters given.
+ *
+ * You might wonder why the heck we implement our own array sorter
+ * instead of using qsort(). The primary reason is that qsort handles
+ * only continuous arrays, but we need to sort array-like data structures
+ * where the only way to access elements is by using an indexing macro.
+ * Besides that, we are more than 2 times faster.
+ *
+ * So much for advocacy, here are the parameters (those marked with [*]
+ * are mandatory):
+ *
+ * ASORT_PREFIX(x) [*] add a name prefix (used on all global names
+ * defined by the sorter)
+ * ASORT_KEY_TYPE [*] data type of a single array entry key
+ * ASORT_ELT(i) returns the key of i-th element; if this macro is not
+ * defined, the function gets a pointer to an array to be sorted
+ * ASORT_LT(x,y) x < y for ASORT_KEY_TYPE (default: "x<y")
+ * ASORT_SWAP(i,j) swap i-th and j-th element (default: assume _ELT
+ * is an l-value and swap just the keys)
+ * ASORT_THRESHOLD threshold for switching between quicksort and insertsort
+ * ASORT_EXTRA_ARGS extra arguments for the sort function (they are always
+ * visible in all the macros supplied above), starts with comma
+ *
+ * After including this file, a function ASORT_PREFIX(sort)(unsigned array_size)
+ * or ASORT_PREFIX(sort)(ASORT_KEY_TYPE *array, unsigned array_size) [if ASORT_ELT
+ * is not defined] is declared and all parameter macros are automatically
+ * undef'd.
+ */
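+
+/*
+ * A minimal usage sketch (the include path is an assumption; adjust it to wherever
+ * this generator header lives in the tree). With ASORT_ELT left undefined, the
+ * generated function takes the array pointer as its first argument:
+ *
+ *   #define ASORT_PREFIX(x) myints_##x
+ *   #define ASORT_KEY_TYPE  int
+ *   #include "contrib/ucw/sorter.h"
+ *
+ *   int a[] = { 3, 1, 2 };
+ *   myints_sort(a, 3);   // afterwards a == { 1, 2, 3 }
+ */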
+
+#ifndef ASORT_LT
+#define ASORT_LT(x,y) ((x) < (y))
+#endif
+
+#ifndef ASORT_SWAP
+#define ASORT_SWAP(i,j) do { ASORT_KEY_TYPE tmp = ASORT_ELT(i); ASORT_ELT(i)=ASORT_ELT(j); ASORT_ELT(j)=tmp; } while (0)
+#endif
+
+#ifndef ASORT_THRESHOLD
+#define ASORT_THRESHOLD 8 /* Guesswork and experimentation */
+#endif
+
+#ifndef ASORT_EXTRA_ARGS
+#define ASORT_EXTRA_ARGS
+#endif
+
+#ifndef ASORT_ELT
+#define ASORT_ARRAY_ARG ASORT_KEY_TYPE *array,
+#define ASORT_ELT(i) array[i]
+#else
+#define ASORT_ARRAY_ARG
+#endif
+
+/**
+ * The generated sorting function. If `ASORT_ELT` macro is not provided, the
+ * @ASORT_ARRAY_ARG is equal to `ASORT_KEY_TYPE *array` and is the array to be
+ * sorted. If the macro is provided, this parameter is omitted. In that case,
+ * you can sort global variables or pass your structure by @ASORT_EXTRA_ARGS.
+ **/
+static void ASORT_PREFIX(sort)(ASORT_ARRAY_ARG unsigned array_size ASORT_EXTRA_ARGS)
+{
+ struct stk { int l, r; } stack[8*sizeof(unsigned)];
+ int l, r, left, right, m;
+ unsigned sp = 0;
+ ASORT_KEY_TYPE pivot;
+
+ if (array_size <= 1)
+ return;
+
+ /* QuickSort with optimizations a'la Sedgewick, but stop at ASORT_THRESHOLD */
+
+ left = 0;
+ right = array_size - 1;
+ for(;;)
+ {
+ l = left;
+ r = right;
+ m = (l+r)/2;
+ if (ASORT_LT(ASORT_ELT(m), ASORT_ELT(l)))
+ ASORT_SWAP(l,m);
+ if (ASORT_LT(ASORT_ELT(r), ASORT_ELT(m)))
+ {
+ ASORT_SWAP(m,r);
+ if (ASORT_LT(ASORT_ELT(m), ASORT_ELT(l)))
+ ASORT_SWAP(l,m);
+ }
+ pivot = ASORT_ELT(m);
+ do
+ {
+ while (ASORT_LT(ASORT_ELT(l), pivot))
+ l++;
+ while (ASORT_LT(pivot, ASORT_ELT(r)))
+ r--;
+ if (l < r)
+ {
+ ASORT_SWAP(l,r);
+ l++;
+ r--;
+ }
+ else if (l == r)
+ {
+ l++;
+ r--;
+ }
+ }
+ while (l <= r);
+ if ((r - left) >= ASORT_THRESHOLD && (right - l) >= ASORT_THRESHOLD)
+ {
+ /* Both partitions ok => push the larger one */
+ if ((r - left) > (right - l))
+ {
+ stack[sp].l = left;
+ stack[sp].r = r;
+ left = l;
+ }
+ else
+ {
+ stack[sp].l = l;
+ stack[sp].r = right;
+ right = r;
+ }
+ sp++;
+ }
+ else if ((r - left) >= ASORT_THRESHOLD)
+ {
+ /* Left partition OK, right undersize */
+ right = r;
+ }
+ else if ((right - l) >= ASORT_THRESHOLD)
+ {
+ /* Right partition OK, left undersize */
+ left = l;
+ }
+ else
+ {
+ /* Both partitions undersize => pop */
+ if (!sp)
+ break;
+ sp--;
+ left = stack[sp].l;
+ right = stack[sp].r;
+ }
+ }
+
+ /*
+ * We have a partially sorted array, finish by insertsort. Inspired
+ * by qsort() in GNU libc.
+ */
+
+ /* Find minimal element which will serve as a barrier */
+ r = MIN(array_size, ASORT_THRESHOLD);
+ m = 0;
+ for (l=1; l<r; l++)
+ if (ASORT_LT(ASORT_ELT(l),ASORT_ELT(m)))
+ m = l;
+ ASORT_SWAP(0,m);
+
+ /* Insertion sort */
+ for (m=1; m<(int)array_size; m++)
+ {
+ l=m;
+ while (ASORT_LT(ASORT_ELT(m),ASORT_ELT(l-1)))
+ l--;
+ while (l < m)
+ {
+ ASORT_SWAP(l,m);
+ l++;
+ }
+ }
+}
+
+#undef ASORT_PREFIX
+#undef ASORT_KEY_TYPE
+#undef ASORT_ELT
+#undef ASORT_LT
+#undef ASORT_SWAP
+#undef ASORT_THRESHOLD
+#undef ASORT_EXTRA_ARGS
+#undef ASORT_ARRAY_ARG
--- /dev/null
+/*
+ * UCW Library -- Generic Binary Search
+ *
+ * (c) 2005 Martin Mares <mj@ucw.cz>
+ *
+ * This software may be freely distributed and used according to the terms
+ * of the GNU Lesser General Public License.
+ */
+
+#pragma once
+
+/***
+ * [[defs]]
+ * Definitions
+ * -----------
+ ***/
+
+/**
+ * Find the first element not lower than \p x in the sorted array \p ary of \p N elements (non-decreasing order).
+ * Returns the index of the found element or \p N if none exists. Uses `ary_lt_x(ary,i,x)` to compare the i'th element with \p x.
+ * The time complexity is `O(log(N))`.
+ **/
+#define BIN_SEARCH_FIRST_GE_CMP(ary, N, ary_lt_x, x, ...) ({ \
+ unsigned l = 0, r = (N); \
+ while (l < r) \
+ { \
+ unsigned m = (l+r)/2; \
+ if (ary_lt_x(ary, m, x, __VA_ARGS__)) \
+ l = m+1; \
+ else \
+ r = m; \
+ } \
+ l; \
+})
+
+/**
+ * The default comparison macro for \ref BIN_SEARCH_FIRST_GE_CMP().
+ **/
+#define ARY_LT_NUM(ary,i,x) (ary)[i] < (x)
+
+/**
+ * Same as \ref BIN_SEARCH_FIRST_GE_CMP(), but uses the default `<` operator for comparisons.
+ **/
+#define BIN_SEARCH_FIRST_GE(ary,N,x) BIN_SEARCH_FIRST_GE_CMP(ary,N,ARY_LT_NUM,x)
+
+/**
+ * Search the sorted array \p ary of \p N elements (non-decreasing) for the first occurrence of \p x.
+ * Returns the index or -1 if no such element exists. Uses the `<` operator for comparisons.
+ **/
+#define BIN_SEARCH_EQ(ary,N,x) ({ int i = BIN_SEARCH_FIRST_GE(ary,N,x); if (i >= (N) || (ary)[i] != (x)) i=-1; i; })
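+
+/*
+ * A minimal usage sketch (hypothetical values):
+ *
+ *   int a[] = { 1, 3, 3, 7 };
+ *   unsigned i = BIN_SEARCH_FIRST_GE(a, 4, 3);   // i == 1, first element >= 3
+ *   int j = BIN_SEARCH_EQ(a, 4, 5);              // j == -1, 5 is not present
+ */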
--- /dev/null
+/*
+ * Binary heap
+ *
+ * (c) 2012 Ondrej Filip <feela@network.cz>
+ *
+ * This software may be freely distributed and used according to the terms
+ * of the GNU Lesser General Public License.
+ */
+
+/***
+ * Introduction
+ * ------------
+ *
+ * Binary heap is a simple data structure, which for example supports efficient insertions, deletions
+ * and access to the minimal inserted item. The original Libucw defines direct macros
+ * for these operations instead of a macro generator as for several other data
+ * structures; this port exposes them as plain functions over heap_val_t pointers
+ * (see heap.h), keeping the representation described below.
+ *
+ * A heap is represented by a number of elements and by an array of values. Beware that we
+ * index this array from one, not from zero as do the standard C arrays.
+ *
+ * Most macros use these parameters:
+ *
+ * - @num - a variable (signed or unsigned integer) with the number of elements
+ * - @heap - a C array of type @type; the heap is stored in `heap[1] .. heap[num]`; `heap[0]` is unused
+ *
+ * A valid heap must follow these rules:
+ *
+ * - `num >= 0`
+ * - `heap[i] >= heap[i / 2]` for each `i` in `[2, num]`
+ *
+ * The first element `heap[1]` is always lower or equal to all other elements.
+ ***/
+
+#include <string.h>
+#include <stdlib.h>
+#include "contrib/ucw/heap.h"
+
+static inline void heap_swap(heap_val_t **e1, heap_val_t **e2)
+{
+ if (e1 == e2) return; /* Stack tmp should be faster than tmpelem. */
+ heap_val_t *tmp = *e1; /* Even faster than 2-XOR nowadays. */
+ *e1 = *e2;
+ *e2 = tmp;
+ int pos = (*e1)->pos;
+ (*e1)->pos = (*e2)->pos;
+ (*e2)->pos = pos;
+}
+
+int heap_init(struct heap *h, int (*cmp)(void *, void *), int init_size)
+{
+ int isize = init_size ? init_size : INITIAL_HEAP_SIZE;
+
+ h->num = 0;
+ h->max_size = isize;
+ h->cmp = cmp;
+ h->data = malloc((isize + 1) * sizeof(heap_val_t*)); /* Temp element unused. */
+
+ return h->data ? 1 : 0;
+}
+
+void heap_deinit(struct heap *h)
+{
+ free(h->data);
+ memset(h, 0, sizeof(*h));
+}
+
+static inline void _heap_bubble_down(struct heap *h, int e)
+{
+ int e1;
+ for (;;)
+ {
+ e1 = 2*e;
+ if(e1 > h->num) break;
+ if((h->cmp(*HELEMENT(h, e),*HELEMENT(h,e1)) < 0) && (e1 == h->num || (h->cmp(*HELEMENT(h, e),*HELEMENT(h,e1+1)) < 0))) break;
+ if((e1 != h->num) && (h->cmp(*HELEMENT(h, e1+1), *HELEMENT(h,e1)) < 0)) e1++;
+ heap_swap(HELEMENT(h,e),HELEMENT(h,e1));
+ e = e1;
+ }
+}
+
+static inline void _heap_bubble_up(struct heap *h, int e)
+{
+ int e1;
+ while (e > 1)
+ {
+ e1 = e/2;
+ if(h->cmp(*HELEMENT(h, e1),*HELEMENT(h,e)) < 0) break;
+ heap_swap(HELEMENT(h,e),HELEMENT(h,e1));
+ e = e1;
+ }
+
+}
+
+static void heap_increase(struct heap *h, int pos, heap_val_t *e)
+{
+ *HELEMENT(h, pos) = e;
+ e->pos = pos;
+ _heap_bubble_down(h, pos);
+}
+
+static void heap_decrease(struct heap *h, int pos, heap_val_t *e)
+{
+ *HELEMENT(h, pos) = e;
+ e->pos = pos;
+ _heap_bubble_up(h, pos);
+}
+
+void heap_replace(struct heap *h, int pos, heap_val_t *e)
+{
+ if (h->cmp(*HELEMENT(h, pos),e) < 0) {
+ heap_increase(h, pos, e);
+ } else {
+ heap_decrease(h, pos, e);
+ }
+}
+
+void heap_delmin(struct heap *h)
+{
+ if(h->num == 0) return;
+ if(h->num > 1)
+ {
+ heap_swap(HHEAD(h),HELEMENT(h,h->num));
+ }
+ (*HELEMENT(h, h->num))->pos = 0;
+ --h->num;
+ _heap_bubble_down(h, 1);
+}
+
+int heap_insert(struct heap *h, heap_val_t *e)
+{
+ if(h->num == h->max_size)
+ {
+ /* Grow via a temporary pointer so the original buffer is not leaked on failure. */
+ int new_size = h->max_size * HEAP_INCREASE_STEP;
+ heap_val_t **new_data = realloc(h->data, (new_size + 1) * sizeof(heap_val_t*));
+ if (!new_data) {
+ return 0;
+ }
+ h->data = new_data;
+ h->max_size = new_size;
+ }
+
+ h->num++;
+ *HELEMENT(h,h->num) = e;
+ e->pos = h->num;
+ _heap_bubble_up(h,h->num);
+ return 1;
+}
+
+int heap_find(struct heap *h, heap_val_t *elm)
+{
+ return ((struct heap_val *) elm)->pos;
+}
+
+void heap_delete(struct heap *h, int e)
+{
+ heap_swap(HELEMENT(h, e), HELEMENT(h, h->num));
+ (*HELEMENT(h, h->num))->pos = 0;
+ h->num--;
+ if(h->cmp(*HELEMENT(h, e), *HELEMENT(h, h->num + 1)) < 0) _heap_bubble_up(h, e);
+ else _heap_bubble_down(h, e);
+
+ if ((h->num > INITIAL_HEAP_SIZE) && (h->num < h->max_size / HEAP_DECREASE_THRESHOLD))
+ {
+ /* Shrink the array; keep the old (larger) buffer if realloc unexpectedly fails. */
+ int new_size = h->max_size / HEAP_INCREASE_STEP;
+ heap_val_t **new_data = realloc(h->data, (new_size + 1) * sizeof(heap_val_t*));
+ if (new_data) {
+ h->data = new_data;
+ h->max_size = new_size;
+ }
+ }
+}
--- /dev/null
+/* Copyright (C) 2011 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+struct heap_val {
+ int pos;
+};
+
+typedef struct heap_val heap_val_t;
+
+struct heap {
+ int num; /* Number of elements */
+ int max_size; /* Size of allocated memory */
+ int (*cmp)(void *, void *);
+ heap_val_t **data;
+}; /* Array follows */
+
+#define INITIAL_HEAP_SIZE 512 /* initial heap size */
+#define HEAP_INCREASE_STEP 2 /* multiplier for each inflation, keep conservative */
+#define HEAP_DECREASE_THRESHOLD 2 /* threshold for deflation, keep conservative */
+#define HELEMENT(h,num) ((h)->data + (num))
+#define HHEAD(h) HELEMENT((h), 1)
+#define EMPTY_HEAP(h) ((h)->num == 0) /* h->num == 0 */
+
+int heap_init(struct heap *, int (*cmp)(void *, void *), int);
+void heap_deinit(struct heap *);
+
+void heap_delmin(struct heap *);
+int heap_insert(struct heap *, heap_val_t *);
+int heap_find(struct heap *, heap_val_t *);
+void heap_delete(struct heap *, int);
+void heap_replace(struct heap *h, int pos, heap_val_t *);
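+
+/*
+ * Usage sketch (illustrative only; the element type and comparator below are
+ * examples, not part of this API). Elements embed struct heap_val as their
+ * first member, and the comparator receives pointers to those elements:
+ *
+ *   struct task {
+ *       struct heap_val hv;  // must be the first member
+ *       int deadline;
+ *   };
+ *
+ *   static int task_cmp(void *a, void *b)
+ *   {
+ *       const struct task *ta = a, *tb = b;
+ *       return ta->deadline - tb->deadline;
+ *   }
+ *
+ *   struct heap h;
+ *   heap_init(&h, task_cmp, 0);                    // 0 -> INITIAL_HEAP_SIZE
+ *   struct task t1 = { .deadline = 30 }, t2 = { .deadline = 10 };
+ *   heap_insert(&h, &t1.hv);
+ *   heap_insert(&h, &t2.hv);
+ *   struct task *min = (struct task *)*HHEAD(&h);  // t2, the smallest deadline
+ *   heap_delmin(&h);                               // removes t2
+ *   heap_deinit(&h);
+ */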
--- /dev/null
+/*
+ * BIRD Library -- Linked Lists
+ *
+ * (c) 1998 Martin Mares <mj@ucw.cz>
+ * (c) 2015, 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ *
+ * Can be freely distributed and used under the terms of the GNU GPL.
+ */
+
+/**
+ * DOC: Linked lists
+ *
+ * The BIRD library provides a set of functions for operating on linked
+ * lists. The lists are internally represented as standard doubly linked
+ * lists with synthetic head and tail which makes all the basic operations
+ * run in constant time and contain no extra end-of-list checks. Each list
+ * is described by a &list structure, nodes can have any format as long
+ * as they start with a &node structure. If you want your nodes to belong
+ * to multiple lists at once, you can embed multiple &node structures in them
+ * and use the SKIP_BACK() macro to calculate a pointer to the start of the
+ * structure from a &node pointer, but beware of obscurity.
+ *
+ * There also exist safe linked lists (&slist, &snode and all functions
+ * being prefixed with |s_|) which support asynchronous walking very
+ * similar to that used in the &fib structure.
+ */
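+
+/*
+ * A minimal sketch of the embedding pattern described above (the struct and
+ * names are illustrative only, not part of this library):
+ *
+ *   struct item {
+ *       node_t n;        // must be the first member
+ *       int value;
+ *   };
+ *
+ *   list_t lst;
+ *   init_list(&lst);
+ *   struct item *it = malloc(sizeof(*it));
+ *   it->value = 42;
+ *   add_tail(&lst, &it->n);
+ *
+ *   struct item *walked;
+ *   WALK_LIST(walked, lst) {
+ *       // walked->value is visited in insertion order
+ *   }
+ */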
+
+#include <stdlib.h>
+#include <string.h>
+#include "contrib/ucw/lists.h"
+#include "contrib/mempattern.h"
+
+/**
+ * add_tail - append a node to a list
+ * \p l: linked list
+ * \p n: list node
+ *
+ * add_tail() takes a node \p n and appends it at the end of the list \p l.
+ */
+void
+add_tail(list_t *l, node_t *n)
+{
+ node_t *z = l->tail;
+
+ n->next = (node_t *) &l->null;
+ n->prev = z;
+ z->next = n;
+ l->tail = n;
+}
+
+/**
+ * add_head - prepend a node to a list
+ * \p l: linked list
+ * \p n: list node
+ *
+ * add_head() takes a node \p n and prepends it at the start of the list \p l.
+ */
+void
+add_head(list_t *l, node_t *n)
+{
+ node_t *z = l->head;
+
+ n->next = z;
+ n->prev = (node_t *) &l->head;
+ z->prev = n;
+ l->head = n;
+}
+
+/**
+ * insert_node - insert a node to a list
+ * \p n: a new list node
+ * \p after: a node of a list
+ *
+ * Inserts a node \p n to a linked list after an already inserted
+ * node \p after.
+ */
+void
+insert_node(node_t *n, node_t *after)
+{
+ node_t *z = after->next;
+
+ n->next = z;
+ n->prev = after;
+ after->next = n;
+ z->prev = n;
+}
+
+/**
+ * rem_node - remove a node from a list
+ * \p n: node to be removed
+ *
+ * Removes a node \p n from the list it's linked in.
+ */
+void
+rem_node(node_t *n)
+{
+ node_t *z = n->prev;
+ node_t *x = n->next;
+
+ z->next = x;
+ x->prev = z;
+ n->prev = 0;
+ n->next = 0;
+}
+
+/**
+ * init_list - create an empty list
+ * \p l: list
+ *
+ * init_list() takes a &list structure and initializes its
+ * fields, so that it represents an empty list.
+ */
+void
+init_list(list_t *l)
+{
+ l->head = (node_t *) &l->null;
+ l->null = NULL;
+ l->tail = (node_t *) &l->head;
+}
+
+/**
+ * add_tail_list - concatenate two lists
+ * \p to: destination list
+ * \p l: source list
+ *
+ * This function appends all elements of the list \p l to
+ * the list \p to in constant time.
+ */
+void
+add_tail_list(list_t *to, list_t *l)
+{
+ node_t *p = to->tail;
+ node_t *q = l->head;
+
+ p->next = q;
+ q->prev = p;
+ q = l->tail;
+ q->next = (node_t *) &to->null;
+ to->tail = q;
+}
+
+/**
+ * list_dup - duplicate list
+ * \p dst: destination list
+ * \p src: source list
+ * \p itemsz: size of each item
+ *
+ * This function duplicates all elements of the list \p src to
+ * the list \p dst in linear time.
+ *
+ * This function only works with a homogeneous item size.
+ */
+void list_dup(list_t *dst, list_t *src, size_t itemsz)
+{
+ node_t *n;
+ WALK_LIST(n, *src) {
+ node_t *i = malloc(itemsz);
+ memcpy(i, n, itemsz);
+ add_tail(dst, i);
+ }
+}
+
+/**
+ * list_size - gets number of nodes
+ * \p l: list
+ *
+ * This function counts nodes in list \p l and returns this number.
+ */
+size_t list_size(const list_t *l)
+{
+ size_t count = 0;
+
+ node_t *n;
+ WALK_LIST(n, *l) {
+ count++;
+ }
+
+ return count;
+}
+
+/**
+ * ptrlist_add - add pointer to pointer list
+ * \p to: destination list
+ * \p val: added pointer
+ * \p mm: memory context
+ */
+ptrnode_t *ptrlist_add(list_t *to, void *val, knot_mm_t *mm)
+{
+ ptrnode_t *node = mm_alloc(mm, sizeof(ptrnode_t));
+ if (node == NULL) {
+ return NULL;
+ } else {
+ node->d = val;
+ }
+ add_tail(to, &node->n);
+ return node;
+}
+
+/**
+ * ptrlist_free - free all nodes in pointer list
+ * \p list: list nodes
+ * \p mm: memory context
+ */
+void ptrlist_free(list_t *list, knot_mm_t *mm)
+{
+ node_t *n, *nxt;
+ WALK_LIST_DELSAFE(n, nxt, *list) {
+ mm_free(mm, n);
+ }
+ init_list(list);
+}
+
+/**
+ * ptrlist_rem - remove pointer node
+ * \p node: pointer node to remove
+ * \p mm: memory context
+ */
+void ptrlist_rem(ptrnode_t *node, knot_mm_t *mm)
+{
+ rem_node(&node->n);
+ mm_free(mm, node);
+}
+
+/**
+ * ptrlist_deep_free - free all nodes including the referenced data
+ * \p list: list nodes
+ * \p mm: memory context
+ */
+void ptrlist_deep_free(list_t *l, knot_mm_t *mm)
+{
+ ptrnode_t *n;
+ WALK_LIST(n, *l) {
+ mm_free(mm, n->d);
+ }
+ ptrlist_free(l, mm);
+}
--- /dev/null
+/*
+ * BIRD Library -- Linked Lists
+ *
+ * (c) 1998 Martin Mares <mj@ucw.cz>
+ * (c) 2015, 2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ *
+ * Can be freely distributed and used under the terms of the GNU GPL.
+ */
+
+#pragma once
+
+/*
+ * I admit the list structure is very tricky and also somewhat awkward,
+ * but it's both efficient and easy to manipulate once one understands the
+ * basic trick: The list head always contains two synthetic nodes which are
+ * always present in the list: the head and the tail. But as the `next'
+ * entry of the tail and the `prev' entry of the head are both NULL, the
+ * nodes can overlap each other:
+ *
+ * head head_node.next
+ * null head_node.prev tail_node.next
+ * tail tail_node.prev
+ */
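+
+/* Illustrative consequence of the layout above: right after init_list(&lst),
+ *   lst.head == (node_t *) &lst.null   (the first "element" is the tail node)
+ *   lst.tail == (node_t *) &lst.head   (the last "element" is the head node)
+ * and EMPTY_LIST(lst) expands to !lst.head->next, which holds exactly while
+ * lst.head still points at the synthetic tail node, whose next is NULL. */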
+
+#include <string.h>
+#include "libknot/mm_ctx.h"
+
+typedef struct node {
+ struct node *next, *prev;
+} node_t;
+
+typedef struct list { /* In fact two overlaid nodes */
+ struct node *head, *null, *tail;
+} list_t;
+
+#define NODE (node_t *)
+#define HEAD(list) ((void *)((list).head))
+#define TAIL(list) ((void *)((list).tail))
+#define WALK_LIST(n,list) for(n=HEAD(list);(NODE (n))->next; \
+ n=(void *)((NODE (n))->next))
+#define WALK_LIST_DELSAFE(n,nxt,list) \
+ for(n=HEAD(list); (nxt=(void *)((NODE (n))->next)); n=(void *) nxt)
+/* WALK_LIST_FIRST supposes that called code removes each processed node */
+#define WALK_LIST_FIRST(n,list) \
+ while(n=HEAD(list), (NODE (n))->next)
+#define WALK_LIST_BACKWARDS(n,list) for(n=TAIL(list);(NODE (n))->prev; \
+ n=(void *)((NODE (n))->prev))
+#define WALK_LIST_BACKWARDS_DELSAFE(n,prv,list) \
+ for(n=TAIL(list); prv=(void *)((NODE (n))->prev); n=(void *) prv)
+
+#define EMPTY_LIST(list) (!(list).head->next)
+
+/*! \brief Free every node in the list. */
+#define WALK_LIST_FREE(list) \
+ do { \
+ node_t *n=0,*nxt=0; \
+ WALK_LIST_DELSAFE(n,nxt,list) { \
+ free(n); \
+ } \
+ init_list(&list); \
+ } while(0)
+
+void add_tail(list_t *, node_t *);
+void add_head(list_t *, node_t *);
+void rem_node(node_t *);
+void add_tail_list(list_t *, list_t *);
+void init_list(list_t *);
+void insert_node(node_t *, node_t *);
+void list_dup(list_t *dst, list_t *src, size_t itemsz);
+size_t list_size(const list_t *);
+
+/*!
+ * \brief Generic pointer list implementation.
+ */
+typedef struct ptrnode {
+ node_t n;
+ void *d;
+} ptrnode_t;
+
+ptrnode_t *ptrlist_add(list_t *, void *, knot_mm_t *);
+void ptrlist_free(list_t *, knot_mm_t *);
+void ptrlist_rem(ptrnode_t *node, knot_mm_t *mm);
+void ptrlist_deep_free(list_t *, knot_mm_t *);
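+
+/*
+ * Sketch of typical ptrlist use (names are illustrative; passing a NULL
+ * knot_mm_t is assumed to fall back to plain malloc/free in knot's mm layer):
+ *
+ *   list_t ptrs;
+ *   init_list(&ptrs);
+ *   ptrlist_add(&ptrs, some_object, NULL);   // stores the pointer in ->d
+ *   ptrnode_t *pn;
+ *   WALK_LIST(pn, ptrs) {
+ *       use(pn->d);                          // the stored pointer
+ *   }
+ *   ptrlist_free(&ptrs, NULL);               // frees the nodes, not the pointees
+ */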
+
--- /dev/null
+/*
+ * UCW Library -- Memory Pools (One-Time Allocation)
+ *
+ * (c) 1997--2001 Martin Mares <mj@ucw.cz>
+ * (c) 2007 Pavel Charvat <pchar@ucw.cz>
+ * (c) 2015, 2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ *
+ * This software may be freely distributed and used according to the terms
+ * of the GNU Lesser General Public License.
+ */
+
+#undef LOCAL_DEBUG
+
+#include <string.h>
+#include <strings.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <assert.h>
+#include "contrib/asan.h"
+#include "contrib/macros.h"
+#include "contrib/ucw/mempool.h"
+
+/** \todo This shouldn't be precalculated, but computed on load. */
+#define CPU_PAGE_SIZE 4096
+
+/** Align an integer \p s to the nearest higher multiple of \p a (which should be a power of two) **/
+#define ALIGN_TO(s, a) (((s)+a-1)&~(a-1))
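+/* e.g. (illustration) ALIGN_TO(13, 8) == 16 and ALIGN_TO(16, 8) == 16 */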
+#define MP_CHUNK_TAIL ALIGN_TO(sizeof(struct mempool_chunk), CPU_STRUCT_ALIGN)
+#define MP_SIZE_MAX (~0U - MP_CHUNK_TAIL - CPU_PAGE_SIZE)
+#define DBG(s, ...)
+
+/** \note Imported MMAP backend from bigalloc.c */
+#define CONFIG_UCW_POOL_IS_MMAP
+#ifdef CONFIG_UCW_POOL_IS_MMAP
+#include <sys/mman.h>
+static void *
+page_alloc(uint64_t len)
+{
+ if (!len) {
+ return NULL;
+ }
+ if (len > SIZE_MAX) {
+ return NULL;
+ }
+ assert(!(len & (CPU_PAGE_SIZE-1)));
+ uint8_t *p = mmap(NULL, len, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0);
+ if (p == (uint8_t*) MAP_FAILED) {
+ return NULL;
+ }
+ return p;
+}
+
+static void
+page_free(void *start, uint64_t len)
+{
+ assert(!(len & (CPU_PAGE_SIZE-1)));
+ assert(!((uintptr_t) start & (CPU_PAGE_SIZE-1)));
+ munmap(start, len);
+}
+#endif
+
+struct mempool_chunk {
+ struct mempool_chunk *next;
+ unsigned size;
+};
+
+static unsigned
+mp_align_size(unsigned size)
+{
+#ifdef CONFIG_UCW_POOL_IS_MMAP
+ return ALIGN_TO(size + MP_CHUNK_TAIL, CPU_PAGE_SIZE) - MP_CHUNK_TAIL;
+#else
+ return ALIGN_TO(size, CPU_STRUCT_ALIGN);
+#endif
+}
+
+void
+mp_init(struct mempool *pool, unsigned chunk_size)
+{
+ chunk_size = mp_align_size(MAX(sizeof(struct mempool), chunk_size));
+ *pool = (struct mempool) {
+ .chunk_size = chunk_size,
+ .threshold = chunk_size >> 1,
+ .last_big = &pool->last_big
+ };
+}
+
+static void *
+mp_new_big_chunk(unsigned size)
+{
+ uint8_t *data = malloc(size + MP_CHUNK_TAIL);
+ if (!data) {
+ return NULL;
+ }
+ ASAN_POISON_MEMORY_REGION(data, size);
+ struct mempool_chunk *chunk = (struct mempool_chunk *)(data + size);
+ chunk->size = size;
+ return chunk;
+}
+
+static void
+mp_free_big_chunk(struct mempool_chunk *chunk)
+{
+ void *ptr = (uint8_t *)chunk - chunk->size;
+ ASAN_UNPOISON_MEMORY_REGION(ptr, chunk->size);
+ free(ptr);
+}
+
+static void *
+mp_new_chunk(unsigned size)
+{
+#ifdef CONFIG_UCW_POOL_IS_MMAP
+ uint8_t *data = page_alloc(size + MP_CHUNK_TAIL);
+ if (!data) {
+ return NULL;
+ }
+ ASAN_POISON_MEMORY_REGION(data, size);
+ struct mempool_chunk *chunk = (struct mempool_chunk *)(data + size);
+ chunk->size = size;
+ return chunk;
+#else
+ return mp_new_big_chunk(size);
+#endif
+}
+
+static void
+mp_free_chunk(struct mempool_chunk *chunk)
+{
+#ifdef CONFIG_UCW_POOL_IS_MMAP
+ uint8_t *data = (uint8_t *)chunk - chunk->size;
+ ASAN_UNPOISON_MEMORY_REGION(data, chunk->size);
+ page_free(data, chunk->size + MP_CHUNK_TAIL);
+#else
+ mp_free_big_chunk(chunk);
+#endif
+}
+
+struct mempool *
+mp_new(unsigned chunk_size)
+{
+ chunk_size = mp_align_size(MAX(sizeof(struct mempool), chunk_size));
+ struct mempool_chunk *chunk = mp_new_chunk(chunk_size);
+ struct mempool *pool = (void *)chunk - chunk_size;
+ ASAN_UNPOISON_MEMORY_REGION(pool, sizeof(*pool));
+ DBG("Creating mempool %p with %u bytes long chunks", pool, chunk_size);
+ chunk->next = NULL;
+ ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ *pool = (struct mempool) {
+ .state = { .free = { chunk_size - sizeof(*pool) }, .last = { chunk } },
+ .chunk_size = chunk_size,
+ .threshold = chunk_size >> 1,
+ .last_big = &pool->last_big
+ };
+ return pool;
+}
+
+static void
+mp_free_chain(struct mempool_chunk *chunk)
+{
+ while (chunk) {
+ ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ struct mempool_chunk *next = chunk->next;
+ mp_free_chunk(chunk);
+ chunk = next;
+ }
+}
+
+static void
+mp_free_big_chain(struct mempool_chunk *chunk)
+{
+ while (chunk) {
+ ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ struct mempool_chunk *next = chunk->next;
+ mp_free_big_chunk(chunk);
+ chunk = next;
+ }
+}
+
+void
+mp_delete(struct mempool *pool)
+{
+ if (pool == NULL) {
+ return;
+ }
+ DBG("Deleting mempool %p", pool);
+ mp_free_big_chain(pool->state.last[1]);
+ mp_free_chain(pool->unused);
+ mp_free_chain(pool->state.last[0]); // can contain the mempool structure
+}
+
+void
+mp_flush(struct mempool *pool)
+{
+ mp_free_big_chain(pool->state.last[1]);
+ struct mempool_chunk *chunk = pool->state.last[0], *next;
+ while (chunk) {
+ ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ if ((uint8_t *)chunk - chunk->size == (uint8_t *)pool) {
+ break;
+ }
+ next = chunk->next;
+ chunk->next = pool->unused;
+ ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ pool->unused = chunk;
+ chunk = next;
+ }
+ pool->state.last[0] = chunk;
+ if (chunk) {
+ pool->state.free[0] = chunk->size - sizeof(*pool);
+ ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ } else {
+ pool->state.free[0] = 0;
+ }
+ pool->state.last[1] = NULL;
+ pool->state.free[1] = 0;
+ pool->last_big = &pool->last_big;
+}
+
+static void
+mp_stats_chain(struct mempool_chunk *chunk, struct mempool_stats *stats, unsigned idx)
+{
+ struct mempool_chunk *next;
+ while (chunk) {
+ ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ stats->chain_size[idx] += chunk->size + sizeof(*chunk);
+ stats->chain_count[idx]++;
+ next = chunk->next;
+ ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ chunk = next;
+ }
+ stats->total_size += stats->chain_size[idx];
+}
+
+void
+mp_stats(struct mempool *pool, struct mempool_stats *stats)
+{
+ bzero(stats, sizeof(*stats));
+ mp_stats_chain(pool->state.last[0], stats, 0);
+ mp_stats_chain(pool->state.last[1], stats, 1);
+ mp_stats_chain(pool->unused, stats, 2);
+}
+
+uint64_t
+mp_total_size(struct mempool *pool)
+{
+ struct mempool_stats stats;
+ mp_stats(pool, &stats);
+ return stats.total_size;
+}
+
+static void *
+mp_alloc_internal(struct mempool *pool, unsigned size)
+{
+ struct mempool_chunk *chunk;
+ if (size <= pool->threshold) {
+ pool->idx = 0;
+ if (pool->unused) {
+ chunk = pool->unused;
+ ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ pool->unused = chunk->next;
+ } else {
+ chunk = mp_new_chunk(pool->chunk_size);
+ }
+ chunk->next = pool->state.last[0];
+ ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ pool->state.last[0] = chunk;
+ pool->state.free[0] = pool->chunk_size - size;
+ return (uint8_t *)chunk - pool->chunk_size;
+ } else if (size <= MP_SIZE_MAX) {
+ pool->idx = 1;
+ unsigned aligned = ALIGN_TO(size, CPU_STRUCT_ALIGN);
+ chunk = mp_new_big_chunk(aligned);
+ if (!chunk) {
+ return NULL;
+ }
+ chunk->next = pool->state.last[1];
+ ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk));
+ pool->state.last[1] = chunk;
+ pool->state.free[1] = aligned - size;
+ return pool->last_big = (uint8_t *)chunk - aligned;
+ } else {
+ fprintf(stderr, "Cannot allocate %u bytes from a mempool\n", size);
+ assert(0);
+ return NULL;
+ }
+}
+
+void *
+mp_alloc(struct mempool *pool, unsigned size)
+{
+ unsigned avail = pool->state.free[0] & ~(CPU_STRUCT_ALIGN - 1);
+ void *ptr = NULL;
+ if (size <= avail) {
+ pool->state.free[0] = avail - size;
+ ptr = (uint8_t*)pool->state.last[0] - avail;
+ } else {
+ ptr = mp_alloc_internal(pool, size);
+ }
+ ASAN_UNPOISON_MEMORY_REGION(ptr, size);
+ return ptr;
+}
+
+void *
+mp_alloc_noalign(struct mempool *pool, unsigned size)
+{
+ void *ptr = NULL;
+ if (size <= pool->state.free[0]) {
+ ptr = (uint8_t*)pool->state.last[0] - pool->state.free[0];
+ pool->state.free[0] -= size;
+ } else {
+ ptr = mp_alloc_internal(pool, size);
+ }
+ ASAN_UNPOISON_MEMORY_REGION(ptr, size);
+ return ptr;
+}
+
+void *
+mp_alloc_zero(struct mempool *pool, unsigned size)
+{
+ void *ptr = mp_alloc(pool, size);
+ bzero(ptr, size);
+ return ptr;
+}
--- /dev/null
+/*
+ * UCW Library -- Memory Pools
+ *
+ * (c) 1997--2005 Martin Mares <mj@ucw.cz>
+ * (c) 2007 Pavel Charvat <pchar@ucw.cz>
+ * (c) 2015, 2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ *
+ * This software may be freely distributed and used according to the terms
+ * of the GNU Lesser General Public License.
+ */
+
+#pragma once
+
+#include <string.h>
+#include <stdint.h>
+
+#define CPU_STRUCT_ALIGN (sizeof(void*))
+
+/***
+ * [[defs]]
+ * Definitions
+ * -----------
+ ***/
+
+/**
+ * Memory pool state (see mp_push(), ...).
+ * You should use this one as an opaque handle only, the insides are internal.
+ **/
+struct mempool_state {
+ unsigned free[2];
+ void *last[2];
+};
+
+/**
+ * Memory pool.
+ * You should use this one as an opaque handle only, the insides are internal.
+ **/
+struct mempool {
+ struct mempool_state state;
+ void *unused, *last_big;
+ unsigned chunk_size, threshold, idx;
+};
+
+struct mempool_stats { /** Mempool statistics. See mp_stats(). **/
+ uint64_t total_size; /** Real allocated size in bytes. */
+ unsigned chain_count[3]; /** Number of allocated chunks in small/big/unused chains. */
+ unsigned chain_size[3]; /** Size of allocated chunks in small/big/unused chains. */
+};
+
+/***
+ * [[basic]]
+ * Basic manipulation
+ * ------------------
+ ***/
+
+/**
+ * Initialize a given mempool structure.
+ * \p chunk_size must be in the interval `[1, UINT_MAX / 2]`.
+ * It will allocate memory in chunks of roughly this size and
+ * satisfy allocation requests from them.
+ *
+ * Memory pools can be treated as <<trans:respools,resources>>, see <<trans:res_mempool()>>.
+ **/
+void mp_init(struct mempool *pool, unsigned chunk_size);
+
+/**
+ * Allocate and initialize a new memory pool.
+ * See \ref mp_init() for \p chunk_size limitations.
+ *
+ * The new mempool structure is allocated on the new mempool.
+ *
+ * Memory pools can be treated as <<trans:respools,resources>>, see <<trans:res_mempool()>>.
+ **/
+struct mempool *mp_new(unsigned chunk_size);
+
+/**
+ * Cleanup mempool initialized by mp_init or mp_new.
+ * Frees all the memory allocated by this mempool and,
+ * if created by \ref mp_new(), the \p pool itself.
+ **/
+void mp_delete(struct mempool *pool);
+
+/**
+ * Frees all data on a memory pool, but leaves it working.
+ * It can keep some of the chunks allocated to serve
+ * further allocation requests. Leaves the \p pool alive,
+ * even if it was created with \ref mp_new().
+ **/
+void mp_flush(struct mempool *pool);
+
+/**
+ * Compute some statistics for debug purposes.
+ * See the definition of the <<struct_mempool_stats,mempool_stats structure>>.
+ **/
+void mp_stats(struct mempool *pool, struct mempool_stats *stats);
+uint64_t mp_total_size(struct mempool *pool); /** How many bytes were allocated by the pool. **/
+
+/***
+ * [[alloc]]
+ * Allocation routines
+ * -------------------
+ ***/
+
+/**
+ * The function allocates new \p size bytes on a given memory pool.
+ * If the \p size is zero, the resulting pointer is undefined,
+ * but it may be safely reallocated or used as the parameter
+ * to other functions below.
+ *
+ * The resulting pointer is always aligned to a multiple of
+ * `CPU_STRUCT_ALIGN` bytes and this condition remains true also
+ * after future reallocations.
+ **/
+void *mp_alloc(struct mempool *pool, unsigned size);
+
+/**
+ * The same as \ref mp_alloc(), but the result may be unaligned.
+ **/
+void *mp_alloc_noalign(struct mempool *pool, unsigned size);
+
+/**
+ * The same as \ref mp_alloc(), but fills the newly allocated memory with zeroes.
+ **/
+void *mp_alloc_zero(struct mempool *pool, unsigned size);
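+
+/*
+ * A minimal usage sketch (variable names are illustrative only):
+ *
+ *   struct mempool *mp = mp_new(4096);   // one-time allocator, 4 KiB chunks
+ *   char *buf = mp_alloc(mp, 256);       // aligned to CPU_STRUCT_ALIGN
+ *   char *str = mp_alloc_zero(mp, 64);   // same, but zero-filled
+ *   ...
+ *   mp_flush(mp);                        // drop all allocations, keep the pool
+ *   mp_delete(mp);                       // free everything, including *mp
+ */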
--- /dev/null
+#include <stdio.h>
+#include <histedit.h>
+#include <string.h>
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#include "lib/generic/array.h"
+#include <contrib/ccan/asprintf/asprintf.h>
+
+#include "commands.h"
+#include "interactive.h"
+#include "process.h"
+#include "deps/lookup.h"
+
+
+static void cmds_lookup(EditLine *el, const char *str, size_t str_len)
+{
+ lookup_t lookup;
+ int ret = lookup_init(&lookup);
+ if (ret != CLI_EOK) {
+ return;
+ }
+
+ /* Fill the lookup with command names of static cmds. */
+ for (const cmd_help_t *cmd_help = cmd_table; cmd_help->name != NULL; cmd_help++) {
+ ret = lookup_insert(&lookup, cmd_help->name, NULL);
+ if (ret != CLI_EOK) {
+ goto cmds_lookup_finish;
+ }
+ }
+
+ /* Fill the lookup with command names of dynamic cmds. */
+ dynarray_foreach(cmd_help, cmd_help_t *, i, dyn_cmd_help_table) {
+ cmd_help_t *cmd_help = *i;
+ ret = lookup_insert(&lookup, cmd_help->name, NULL);
+ if (ret != CLI_EOK) {
+ goto cmds_lookup_finish;
+ }
+ }
+
+ lookup_complete(&lookup, str, str_len, el, true);
+
+ cmds_lookup_finish:
+ lookup_deinit(&lookup);
+}
+
+static unsigned char complete(EditLine *el, int ch)
+{
+ int argc, token, pos;
+ const char **argv;
+
+ const LineInfo *li = el_line(el);
+ Tokenizer *tok = tok_init(NULL);
+
+ /* Parse the line. */
+ int ret = tok_line(tok, li, &argc, &argv, &token, &pos);
+ if (ret != 0) {
+ goto complete_exit;
+ }
+
+ /* Show possible commands. */
+ if (argc == 0) {
+ print_commands(NULL);
+ goto complete_exit;
+ }
+
+ /* Complete the command name. */
+ if (token == 0) {
+ cmds_lookup(el, argv[0], pos);
+ goto complete_exit;
+ }
+
+ /* Find the command descriptor. */
+ const cmd_desc_t *desc = cmd_table;
+ while (desc->name != NULL && strcmp(desc->name, argv[0]) != 0) {
+ desc++;
+ }
+ if (desc->name == NULL) {
+ goto complete_exit;
+ }
+
+ complete_exit:
+ tok_reset(tok);
+ tok_end(tok);
+
+ return CC_REDISPLAY;
+}
+
+static char *prompt(EditLine *el)
+{
+ return PROGRAM_NAME"> ";
+}
+
+int interactive_loop(params_t *process_params)
+{
+ char *hist_file = NULL;
+ const char *home = getenv("HOME");
+ if (home != NULL) {
+ asprintf(&hist_file, "%s/"HISTORY_FILE, home);
+ }
+ if (hist_file == NULL) {
+ printf("failed to get home directory");
+ }
+
+ EditLine *el = el_init(PROGRAM_NAME, stdin, stdout, stderr);
+ if (el == NULL) {
+ printf("interactive mode not available");
+ free(hist_file);
+ return 1;
+ }
+
+ History *hist = history_init();
+ if (hist == NULL) {
+ printf("interactive mode not available");
+ el_end(el);
+ free(hist_file);
+ return 1;
+ }
+
+ HistEvent hev = { 0 };
+ history(hist, &hev, H_SETSIZE, 100);
+ el_set(el, EL_HIST, history, hist);
+ history(hist, &hev, H_LOAD, hist_file);
+
+ el_set(el, EL_TERMINAL, NULL);
+ el_set(el, EL_EDITOR, "emacs");
+ el_set(el, EL_PROMPT, prompt);
+ el_set(el, EL_SIGNAL, 1);
+ el_source(el, NULL);
+
+ el_set(el, EL_ADDFN, PROGRAM_NAME"-complete",
+ "Perform "PROGRAM_NAME" completion.", complete);
+ el_set(el, EL_BIND, "^I", PROGRAM_NAME"-complete", NULL);
+
+ int count;
+ const char *line;
+ while ((line = el_gets(el, &count)) != NULL && count > 0) {
+ history(hist, &hev, H_ENTER, line);
+
+ Tokenizer *tok = tok_init(NULL);
+
+ /* Tokenize the current line. */
+ int argc;
+ const char **argv;
+ const LineInfo *li = el_line(el);
+ int ret = tok_line(tok, li, &argc, &argv, NULL, NULL);
+ if (ret != 0) {
+ tok_reset(tok);
+ tok_end(tok);
+ continue;
+ }
+
+ /* Process the command. */
+ ret = process_cmd(argc, argv, process_params);
+
+ history(hist, &hev, H_SAVE, hist_file);
+ tok_reset(tok);
+ tok_end(tok);
+
+ /* Check for the exit command. */
+ if (ret == CLI_EXIT) {
+ break;
+ }
+ }
+
+ history_end(hist);
+ free(hist_file);
+
+ el_end(el);
+
+ return 0;
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include "process.h"
+
+
+/** CLI interactive loop */
+int interactive_loop(params_t *params);
\ No newline at end of file
+#include <getopt.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "kresconfig.h"
+#include "interactive.h"
+#include "process.h"
+#include "commands.h"
+
+
+params_t params = {
+ .timeout = SYSREPO_TIMEOUT * 1000,
+};
+
+sysrepo_ctx_t *sysrepo_ctx;
+sysrepo_ctx_t sysrepo_ctx_value = {
+ .connection = NULL,
+ .session = NULL
+};
+
+static void print_help(void)
+{
+ print_version(NULL);
+
+ printf("\nUsage:\n"
+ " %s [parameters] <command> [command-arguments]\n"
+ "\n"
+ "Parameters:\n"
+ " -t, --timeout <sec> "SPACE"Timeout for sysrepo operations.\n"
+ " "SPACE" (default %d seconds)\n"
+ " -h, --help "SPACE"Print the program help.\n"
+ " -V, --version "SPACE"Print the program version.\n",
+ PROGRAM_NAME, SYSREPO_TIMEOUT);
+
+ print_commands(NULL);
+}
+
int main(int argc, char *argv[])
{
- return 0;
+ /* Long options. */
+ struct option opts[] = {
+ { "timeout", required_argument, NULL, 't' },
+ { "help", no_argument, NULL, 'h' },
+ { "version", no_argument, NULL, 'V' },
+ { NULL }
+ };
+
+ /* Init sysrepo connection */
+ sysrepo_ctx = &sysrepo_ctx_value;
+ int ret = sr_connect(0, &sysrepo_ctx->connection);
+ if (ret){
+ printf("[kresctl] failed to connect to sysrepo: %s\n",
+ sr_strerror(ret));
+ goto sr_cleanup;
+ }
+
+ /* Create dynamic commands table */
+ ret = create_cmd_table(sysrepo_ctx->connection);
+ if (ret){
+ printf("[kresctl] failed to create commands table\n");
+ goto sr_cleanup;
+ }
+
+ /* Parse command line parameters */
+ int opt = 0;
+ while ((opt = getopt_long(argc, argv, "+t:hV", opts, NULL)) != -1) {
+ switch (opt) {
+ case 't':
+ params.timeout = atoi(optarg);
+ if (params.timeout < 1) {
+ printf("[kresctl] error '-t' requires a positive"
+ " number, not '%s'\n", optarg);
+ return EXIT_FAILURE;
+ }
+ /* Convert to milliseconds. */
+ params.timeout *= 1000;
+ break;
+ case 'h':
+ print_help();
+ return EXIT_SUCCESS;
+ case 'V':
+ print_version(NULL);
+ return EXIT_SUCCESS;
+ default:
+ print_help();
+ return EXIT_FAILURE;
+ }
+ }
+
+ if (argc - optind < 1) {
+ /* start interactive loop */
+ ret = interactive_loop(&params);
+ } else {
+ /*
+ * Session with RUNNING datastore
+ * needs to be created here, because there
+ * is no interactive loop to create
+ * transaction with candidate datastore.
+ */
+ ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sysrepo_ctx->session);
+ if (ret) {
+ printf("failed to start sysrepo session, %s\n", sr_strerror(ret));
+ goto cleanup;
+ }
+ /* Execute the command given on the command line. */
+ ret = process_cmd(argc - optind, (const char **)argv + optind, &params);
+
+ /* Stop sysrepo session (don't let it overwrite the command result). */
+ int stop_ret = sr_session_stop(sysrepo_ctx->session);
+ if (stop_ret) {
+ printf("failed to stop sysrepo session, %s\n", sr_strerror(stop_ret));
+ }
+ }
+
+cleanup:
+ /* free all dynamic tables */
+ destroy_cmd_table();
+
+sr_cleanup:
+ sr_disconnect(sysrepo_ctx->connection);
+
+ return (ret == CLI_EOK) ? EXIT_SUCCESS : EXIT_FAILURE;
}
\ No newline at end of file
kresctl_src = files([
'main.c',
+ 'process.c',
+ 'interactive.c',
+ 'commands.c',
+ 'conf_file.c',
+ 'deps/lookup.c',
+ 'deps/mempattern.c',
+ 'deps/string.c',
])
c_src_lint += kresctl_src
-if build_sysrepo
- kresc = executable(
+message('--- kresctl dependencies ---')
+libedit = dependency('libedit', required: false)
+
+if build_sysrepo and libedit.found()
+ kresctl = executable(
'kresctl',
kresctl_src,
sysrepo_common_src,
dependencies: [
contrib_dep,
libkres_dep,
+ libedit,
libyang,
libsysrepo,
libknot,
- luajit_inc,
],
install: true,
install_dir: get_option('sbindir'),
--- /dev/null
+#include <errno.h>
+#include <string.h>
+
+#include "contrib/dynarray.h"
+#include "process.h"
+#include "commands.h"
+
+
+static const cmd_desc_t *get_cmd_desc(const char *command)
+{
+ /* Try to find requested command in built-in commands. */
+ const cmd_desc_t *desc = cmd_table;
+ while (desc->name != NULL) {
+ if (strcmp(desc->name, command) == 0) {
+ break;
+ }
+ desc++;
+ }
+
+ /* Try to find requested command in created commands. */
+ dynarray_foreach(cmd, cmd_desc_t *, i, dyn_cmd_table) {
+ cmd_desc_t *dyn_desc = *i;
+ if (strcmp(dyn_desc->name, command) == 0) {
+ desc = dyn_desc;
+ break;
+ }
+ }
+
+ if (desc->name == NULL) {
+ printf("invalid command '%s'\n", command);
+ return NULL;
+ }
+
+ return desc;
+}
+
+int process_cmd(int argc, const char **argv, params_t *params)
+{
+ if (argc == 0) {
+ return ENOTSUP;
+ }
+
+ /* Check the command name. */
+ const cmd_desc_t *desc = get_cmd_desc(argv[0]);
+ if (desc == NULL) {
+ return ENOENT;
+ }
+
+ /* Check for program exit. */
+ if (desc->fcn == NULL) {
+ return CLI_EXIT;
+ }
+
+ /* Prepare command arguments. */
+ cmd_args_t args = {
+ .desc = desc,
+ .argc = argc - 1,
+ .argv = argv + 1,
+ .timeout = params->timeout,
+ };
+
+ /* Execute the command. */
+ int ret = desc->fcn(&args);
+
+ return ret;
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <stdio.h>
+#include <stdbool.h>
+#include <sysrepo.h>
+
+#include "kresconfig.h"
+#include "process.h"
+#include "commands.h"
+
+/* CLI globals */
+#define PROJECT_NAME "Knot Resolver"
+#define PROGRAM_NAME "kresctl"
+#define PROGRAM_DESC "control/administration tool"
+#define HISTORY_FILE ".kresctl_history"
+#define SPACE " "
+/* default values */
+#define SYSREPO_TIMEOUT 10
+/* return codes */
+#define CLI_EOK 0
+#define CLI_ERR 1
+#define CLI_ECMD 2
+#define CLI_EXIT -10
+
+
+/* CLI parameters. */
+typedef struct {
+ int timeout;
+ int max_depth;
+} params_t;
+
+/* CLI context */
+typedef struct {
+ sr_conn_ctx_t *connection;
+ sr_session_ctx_t *session;
+} sysrepo_ctx_t;
+
+extern sysrepo_ctx_t *sysrepo_ctx;
+
+int process_cmd(int argc, const char **argv, params_t *params);
build_utils = get_option('utils') != 'disabled'
subdir('kresctl')
-subdir('kres_watcher')
+subdir('watcher')
subdir('client')
subdir('cache_gc')
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <lua.h>
+
+/** Make all the bindings accessible from the lua state,
+ * i.e. define those lua tables. */
+void kr_bindings_register(lua_State *L);
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "bindings/impl.h"
+
+#include "worker.h"
+#include "zimport.h"
+
+/** @internal return cache, or throw lua error if not open */
+struct kr_cache * cache_assert_open(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ struct kr_cache *cache = &engine->resolver.cache;
+ assert(cache);
+ if (!cache || !kr_cache_is_open(cache))
+ lua_error_p(L, "no cache is open yet, use cache.open() or cache.size, etc.");
+ return cache;
+}
+
+/** Return available cached backends. */
+static int cache_backends(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+
+ lua_newtable(L);
+ for (unsigned i = 0; i < engine->backends.len; ++i) {
+ const struct kr_cdb_api *api = engine->backends.at[i];
+ lua_pushboolean(L, api == engine->resolver.cache.api);
+ lua_setfield(L, -2, api->name);
+ }
+ return 1;
+}
+
+/** Return number of cached records. */
+static int cache_count(lua_State *L)
+{
+ struct kr_cache *cache = cache_assert_open(L);
+
+ int count = cache->api->count(cache->db, &cache->stats);
+ if (count >= 0) {
+ /* First key is a version counter, omit it if nonempty. */
+ lua_pushinteger(L, count ? count - 1 : 0);
+ return 1;
+ }
+ return 0;
+}
+
+/** Return time of last checkpoint, or re-set it if passed `true`. */
+static int cache_checkpoint(lua_State *L)
+{
+ struct kr_cache *cache = cache_assert_open(L);
+
+ if (lua_gettop(L) == 0) { /* Return the current value. */
+ lua_newtable(L);
+ lua_pushnumber(L, cache->checkpoint_monotime);
+ lua_setfield(L, -2, "monotime");
+ lua_newtable(L);
+ lua_pushnumber(L, cache->checkpoint_walltime.tv_sec);
+ lua_setfield(L, -2, "sec");
+ lua_pushnumber(L, cache->checkpoint_walltime.tv_usec);
+ lua_setfield(L, -2, "usec");
+ lua_setfield(L, -2, "walltime");
+ return 1;
+ }
+
+ if (lua_gettop(L) != 1 || !lua_isboolean(L, 1) || !lua_toboolean(L, 1))
+ lua_error_p(L, "cache.checkpoint() takes no parameters or a true value");
+
+ kr_cache_make_checkpoint(cache);
+ return 1;
+}
+
+/** Return cache statistics. */
+static int cache_stats(lua_State *L)
+{
+ struct kr_cache *cache = cache_assert_open(L);
+ lua_newtable(L);
+#define add_stat(name) \
+ lua_pushinteger(L, (cache->stats.name)); \
+ lua_setfield(L, -2, #name)
+ add_stat(open);
+ add_stat(close);
+ add_stat(count);
+ add_stat(clear);
+ add_stat(commit);
+ add_stat(read);
+ add_stat(read_miss);
+ add_stat(write);
+ add_stat(remove);
+ add_stat(remove_miss);
+ add_stat(match);
+ add_stat(match_miss);
+ add_stat(read_leq);
+ add_stat(read_leq_miss);
+#undef add_stat
+
+ return 1;
+}
+
+static const struct kr_cdb_api *cache_select(struct engine *engine, const char **conf)
+{
+ /* Return default backend */
+ if (*conf == NULL || !strstr(*conf, "://")) {
+ return engine->backends.at[0];
+ }
+
+ /* Find storage backend from config prefix */
+ for (unsigned i = 0; i < engine->backends.len; ++i) {
+ const struct kr_cdb_api *api = engine->backends.at[i];
+ if (strncmp(*conf, api->name, strlen(api->name)) == 0) {
+ *conf += strlen(api->name) + strlen("://");
+ return api;
+ }
+ }
+
+ return NULL;
+}
+
+static int cache_max_ttl(lua_State *L)
+{
+ struct kr_cache *cache = cache_assert_open(L);
+
+ int n = lua_gettop(L);
+ if (n > 0) {
+ if (!lua_isnumber(L, 1) || n > 1)
+ lua_error_p(L, "expected 'max_ttl(number ttl)'");
+ uint32_t min = cache->ttl_min;
+ int64_t ttl = lua_tointeger(L, 1);
+ if (ttl < 1 || ttl < min || ttl > UINT32_MAX) {
+ lua_error_p(L,
+ "max_ttl must be larger than minimum TTL, and in range <1, "
+ STR(UINT32_MAX) ">");
+ }
+ cache->ttl_max = ttl;
+ }
+ lua_pushinteger(L, cache->ttl_max);
+ return 1;
+}
+
+
+static int cache_min_ttl(lua_State *L)
+{
+ struct kr_cache *cache = cache_assert_open(L);
+
+ int n = lua_gettop(L);
+ if (n > 0) {
+ if (!lua_isnumber(L, 1))
+ lua_error_p(L, "expected 'min_ttl(number ttl)'");
+ uint32_t max = cache->ttl_max;
+ int64_t ttl = lua_tointeger(L, 1);
+ if (ttl < 0 || ttl > max || ttl > UINT32_MAX) {
+ lua_error_p(L,
+ "min_ttl must be smaller than maximum TTL, and in range <0, "
+ STR(UINT32_MAX) ">");
+ }
+ cache->ttl_min = ttl;
+ }
+ lua_pushinteger(L, cache->ttl_min);
+ return 1;
+}
+
+/** Open cache */
+static int cache_open(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 1 || !lua_isnumber(L, 1))
+ lua_error_p(L, "expected 'open(number max_size, string config = \"\")'");
+
+ /* Select cache storage backend */
+ struct engine *engine = engine_luaget(L);
+
+ lua_Integer csize_lua = lua_tointeger(L, 1);
+ if (!(csize_lua >= 8192 && csize_lua < SIZE_MAX)) { /* min. is basically arbitrary */
+ lua_error_p(L, "invalid cache size specified, it must be in range <8192, "
+ STR(SIZE_MAX) ">");
+ }
+ size_t cache_size = csize_lua;
+
+ const char *conf = n > 1 ? lua_tostring(L, 2) : NULL;
+ const char *uri = conf;
+ const struct kr_cdb_api *api = cache_select(engine, &conf);
+ if (!api)
+ lua_error_p(L, "unsupported cache backend");
+
+ /* Close if already open */
+ kr_cache_close(&engine->resolver.cache);
+
+ /* Reopen cache */
+ struct kr_cdb_opts opts = {
+ (conf && strlen(conf)) ? conf : ".",
+ cache_size
+ };
+ int ret = kr_cache_open(&engine->resolver.cache, api, &opts, engine->pool);
+ if (ret != 0) {
+ char cwd[PATH_MAX];
+ get_workdir(cwd, sizeof(cwd));
+ return luaL_error(L, "can't open cache path '%s'; working directory '%s'", opts.path, cwd);
+ }
+
+ /* Store current configuration */
+ lua_getglobal(L, "cache");
+ lua_pushstring(L, "current_size");
+ lua_pushnumber(L, cache_size);
+ lua_rawset(L, -3);
+ lua_pushstring(L, "current_storage");
+ lua_pushstring(L, uri);
+ lua_rawset(L, -3);
+ lua_pop(L, 1);
+
+ lua_pushboolean(L, 1);
+ return 1;
+}
+
+static int cache_close(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ struct kr_cache *cache = &engine->resolver.cache;
+ if (!kr_cache_is_open(cache)) {
+ return 0;
+ }
+
+ kr_cache_close(cache);
+ lua_getglobal(L, "cache");
+ lua_pushstring(L, "current_size");
+ lua_pushnumber(L, 0);
+ lua_rawset(L, -3);
+ lua_pop(L, 1);
+ lua_pushboolean(L, 1);
+ return 1;
+}
+
+#if 0
+/** @internal Prefix walk. */
+static int cache_prefixed(struct kr_cache *cache, const char *prefix, bool exact_name,
+ knot_db_val_t keyval[][2], int maxcount)
+{
+ /* Convert to domain name */
+ uint8_t buf[KNOT_DNAME_MAXLEN];
+ if (!knot_dname_from_str(buf, prefix, sizeof(buf))) {
+ return kr_error(EINVAL);
+ }
+ /* Start prefix search */
+ return kr_cache_match(cache, buf, exact_name, keyval, maxcount);
+}
+#endif
+
+/** Clear everything. */
+static int cache_clear_everything(lua_State *L)
+{
+ struct kr_cache *cache = cache_assert_open(L);
+
+ /* Clear records and packets. */
+ int ret = kr_cache_clear(cache);
+ lua_error_maybe(L, ret);
+
+ /* Clear reputation tables */
+ struct engine *engine = engine_luaget(L);
+ lru_reset(engine->resolver.cache_rtt);
+ lru_reset(engine->resolver.cache_rep);
+ lru_reset(engine->resolver.cache_cookie);
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+#if 0
+/** @internal Dump cache key into table on Lua stack. */
+static void cache_dump(lua_State *L, knot_db_val_t keyval[])
+{
+ knot_dname_t dname[KNOT_DNAME_MAXLEN];
+ char name[KNOT_DNAME_TXT_MAXLEN];
+ uint16_t type;
+
+ int ret = kr_unpack_cache_key(keyval[0], dname, &type);
+ if (ret < 0) {
+ return;
+ }
+
+ ret = !knot_dname_to_str(name, dname, sizeof(name));
+ assert(!ret);
+ if (ret) return;
+
+ /* If name typemap doesn't exist yet, create it */
+ lua_getfield(L, -1, name);
+ if (lua_isnil(L, -1)) {
+ lua_pop(L, 1);
+ lua_newtable(L);
+ }
+ /* Append to typemap */
+ char type_buf[KR_RRTYPE_STR_MAXLEN] = { '\0' };
+ knot_rrtype_to_string(type, type_buf, sizeof(type_buf));
+ lua_pushboolean(L, true);
+ lua_setfield(L, -2, type_buf);
+ /* Set name typemap */
+ lua_setfield(L, -2, name);
+}
+
+/** Query cached records. TODO: fix caveats in ./README.rst documentation? */
+static int cache_get(lua_State *L)
+{
+ //struct kr_cache *cache = cache_assert_open(L); // to be fixed soon
+
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "expected 'cache.get(string key)'");
+
+ /* Retrieve set of keys */
+ const char *prefix = lua_tostring(L, 1);
+ knot_db_val_t keyval[100][2];
+ int ret = cache_prefixed(cache, prefix, false/*FIXME*/, keyval, 100);
+ lua_error_maybe(L, ret);
+ /* Format output */
+ lua_newtable(L);
+ for (int i = 0; i < ret; ++i) {
+ cache_dump(L, keyval[i]);
+ }
+ return 1;
+}
+#endif
+static int cache_get(lua_State *L)
+{
+ lua_error_maybe(L, ENOSYS);
+ return kr_error(ENOSYS); /* doesn't happen */
+}
+
+/** Set time interval for cleaning rtt cache.
+ * Servers with score >= KR_NS_TIMEOUT will be cleaned after
+ * this interval has elapsed, so that they will be able to participate
+ * in NS elections again. */
+static int cache_ns_tout(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ struct kr_context *ctx = &engine->resolver;
+
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 1) {
+ lua_pushinteger(L, ctx->cache_rtt_tout_retry_interval);
+ return 1;
+ }
+
+ if (!lua_isnumber(L, 1))
+ lua_error_p(L, "expected 'cache.ns_tout(interval in ms)'");
+
+ lua_Integer interval_lua = lua_tointeger(L, 1);
+ if (!(interval_lua > 0 && interval_lua < UINT_MAX)) {
+ lua_error_p(L, "invalid interval specified, it must be in range > 0, < "
+ STR(UINT_MAX));
+ }
+
+ ctx->cache_rtt_tout_retry_interval = interval_lua;
+ lua_pushinteger(L, ctx->cache_rtt_tout_retry_interval);
+ return 1;
+}
+
+/** Zone import completion callback.
+ * Deallocates zone import context. */
+static void cache_zone_import_cb(int state, void *param)
+{
+ assert (param);
+ (void)state;
+ struct worker_ctx *worker = (struct worker_ctx *)param;
+ assert (worker->z_import);
+ zi_free(worker->z_import);
+ worker->z_import = NULL;
+}
+
+/** Import zone from file. */
+static int cache_zone_import(lua_State *L)
+{
+ int ret = -1;
+ char msg[128];
+
+ struct worker_ctx *worker = the_worker;
+ if (!worker) {
+ strncpy(msg, "internal error, empty worker pointer", sizeof(msg));
+ goto finish;
+ }
+
+ if (worker->z_import && zi_import_started(worker->z_import)) {
+ strncpy(msg, "import already started", sizeof(msg));
+ goto finish;
+ }
+
+ (void)cache_assert_open(L); /* just check it in advance */
+
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 1 || !lua_isstring(L, 1)) {
+ strncpy(msg, "expected 'cache.zone_import(path to zone file)'", sizeof(msg));
+ goto finish;
+ }
+
+ /* Parse zone file */
+ const char *zone_file = lua_tostring(L, 1);
+
+ const char *default_origin = NULL; /* TODO */
+ uint16_t default_rclass = 1;
+ uint32_t default_ttl = 0;
+
+ if (worker->z_import == NULL) {
+ worker->z_import = zi_allocate(worker, cache_zone_import_cb, worker);
+ if (worker->z_import == NULL) {
+ strncpy(msg, "can't allocate zone import context", sizeof(msg));
+ goto finish;
+ }
+ }
+
+ ret = zi_zone_import(worker->z_import, zone_file, default_origin,
+ default_rclass, default_ttl);
+
+ lua_newtable(L);
+ if (ret == 0) {
+ strncpy(msg, "zone file successfully parsed, import started", sizeof(msg));
+ } else if (ret == 1) {
+ strncpy(msg, "TA not found", sizeof(msg));
+ } else {
+ strncpy(msg, "error parsing zone file", sizeof(msg));
+ }
+
+finish:
+ msg[sizeof(msg) - 1] = 0;
+ lua_newtable(L);
+ lua_pushstring(L, msg);
+ lua_setfield(L, -2, "msg");
+ lua_pushnumber(L, ret);
+ lua_setfield(L, -2, "code");
+
+ return 1;
+}
+
+int kr_bindings_cache(lua_State *L)
+{
+ static const luaL_Reg lib[] = {
+ { "backends", cache_backends },
+ { "count", cache_count },
+ { "stats", cache_stats },
+ { "checkpoint", cache_checkpoint },
+ { "open", cache_open },
+ { "close", cache_close },
+ { "clear_everything", cache_clear_everything },
+ { "get", cache_get },
+ { "max_ttl", cache_max_ttl },
+ { "min_ttl", cache_min_ttl },
+ { "ns_tout", cache_ns_tout },
+ { "zone_import", cache_zone_import },
+ { NULL, NULL }
+ };
+
+ luaL_register(L, "cache", lib);
+ return 1;
+}
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "bindings/impl.h"
+
+#include "worker.h"
+
+#include <unistd.h>
+#include <uv.h>
+
+static void event_free(uv_timer_t *timer)
+{
+ struct worker_ctx *worker = timer->loop->data;
+ lua_State *L = worker->engine->L;
+ int ref = (intptr_t) timer->data;
+ luaL_unref(L, LUA_REGISTRYINDEX, ref);
+ free(timer);
+}
+
+static void event_callback(uv_timer_t *timer)
+{
+ struct worker_ctx *worker = timer->loop->data;
+ lua_State *L = worker->engine->L;
+
+ /* Retrieve callback and execute */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, (intptr_t) timer->data);
+ lua_rawgeti(L, -1, 1);
+ lua_pushinteger(L, (intptr_t) timer->data);
+ int ret = execute_callback(L, 1);
+ /* Free callback if not recurrent or an error */
+ if (ret != 0 || (uv_timer_get_repeat(timer) == 0 && uv_is_active((uv_handle_t *)timer) == 0)) {
+ if (!uv_is_closing((uv_handle_t *)timer)) {
+ uv_close((uv_handle_t *)timer, (uv_close_cb) event_free);
+ }
+ }
+}
+
+static void event_fdcallback(uv_poll_t* handle, int status, int events)
+{
+ struct worker_ctx *worker = handle->loop->data;
+ lua_State *L = worker->engine->L;
+
+ /* Retrieve callback and execute */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, (intptr_t) handle->data);
+ lua_rawgeti(L, -1, 1);
+ lua_pushinteger(L, (intptr_t) handle->data);
+ lua_pushinteger(L, status);
+ lua_pushinteger(L, events);
+ int ret = execute_callback(L, 3);
+ /* Free callback if not recurrent or an error */
+ if (ret != 0) {
+ if (!uv_is_closing((uv_handle_t *)handle)) {
+ uv_close((uv_handle_t *)handle, (uv_close_cb) event_free);
+ }
+ }
+}
+
+static int event_sched(lua_State *L, unsigned timeout, unsigned repeat)
+{
+ uv_timer_t *timer = malloc(sizeof(*timer));
+ if (!timer)
+ lua_error_p(L, "out of memory");
+
+ /* Start timer with the reference */
+ uv_loop_t *loop = uv_default_loop();
+ uv_timer_init(loop, timer);
+ int ret = uv_timer_start(timer, event_callback, timeout, repeat);
+ if (ret != 0) {
+ free(timer);
+ lua_error_p(L, "couldn't start the event");
+ }
+
+ /* Save callback and timer in registry */
+ lua_newtable(L);
+ lua_pushvalue(L, 2);
+ lua_rawseti(L, -2, 1);
+ lua_pushpointer(L, timer);
+ lua_rawseti(L, -2, 2);
+ int ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ /* Save reference to the timer */
+ timer->data = (void *) (intptr_t)ref;
+ lua_pushinteger(L, ref);
+ return 1;
+}
+
+static int event_after(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 2 || !lua_isnumber(L, 1) || !lua_isfunction(L, 2))
+ lua_error_p(L, "expected 'after(number timeout, function)'");
+
+ return event_sched(L, lua_tointeger(L, 1), 0);
+}
+
+static int event_recurrent(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 2 || !lua_isnumber(L, 1) || !lua_isfunction(L, 2))
+ lua_error_p(L, "expected 'recurrent(number interval, function)'");
+
+ return event_sched(L, 0, lua_tointeger(L, 1));
+}
+
+static int event_cancel(lua_State *L)
+{
+ int n = lua_gettop(L);
+ if (n < 1 || !lua_isnumber(L, 1))
+ lua_error_p(L, "expected 'cancel(number event)'");
+
+ /* Fetch event if it exists */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, lua_tointeger(L, 1));
+ bool ok = lua_istable(L, -1);
+
+ /* Close the timer */
+ uv_handle_t **timer_pp = NULL;
+ if (ok) {
+ lua_rawgeti(L, -1, 2);
+ timer_pp = lua_touserdata(L, -1);
+ ok = timer_pp && *timer_pp;
+ /* Those should have been sufficient safety checks, hopefully. */
+ }
+ if (ok && !uv_is_closing(*timer_pp)) {
+ uv_close(*timer_pp, (uv_close_cb)event_free);
+ }
+ lua_pushboolean(L, ok);
+ return 1;
+}
+
+static int event_reschedule(lua_State *L)
+{
+ int n = lua_gettop(L);
+ if (n < 2 || !lua_isnumber(L, 1) || !lua_isnumber(L, 2))
+ lua_error_p(L, "expected 'reschedule(number event, number timeout)'");
+
+ /* Fetch event if it exists */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, lua_tointeger(L, 1));
+ bool ok = lua_istable(L, -1);
+
+ /* Reschedule the timer */
+ uv_handle_t **timer_pp = NULL;
+ if (ok) {
+ lua_rawgeti(L, -1, 2);
+ timer_pp = lua_touserdata(L, -1);
+ ok = timer_pp && *timer_pp;
+ /* Those should have been sufficient safety checks, hopefully. */
+ }
+ if (ok && !uv_is_closing(*timer_pp)) {
+ int ret = uv_timer_start((uv_timer_t *)*timer_pp,
+ event_callback, lua_tointeger(L, 2), 0);
+ if (ret != 0) {
+ uv_close(*timer_pp, (uv_close_cb)event_free);
+ ok = false;
+ }
+ }
+ lua_pushboolean(L, ok);
+ return 1;
+}
+
+static int event_fdwatch(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 2 || !lua_isnumber(L, 1) || !lua_isfunction(L, 2))
+ lua_error_p(L, "expected 'socket(number fd, function)'");
+
+ uv_poll_t *handle = malloc(sizeof(*handle));
+ if (!handle)
+ lua_error_p(L, "out of memory");
+
+ /* Start timer with the reference */
+ int sock = lua_tointeger(L, 1);
+ uv_loop_t *loop = uv_default_loop();
+ int ret = uv_poll_init(loop, handle, sock);
+ if (ret == 0)
+ ret = uv_poll_start(handle, UV_READABLE, event_fdcallback);
+ if (ret != 0) {
+ free(handle);
+ lua_error_p(L, "couldn't start event poller");
+ }
+
+ /* Save callback and timer in registry */
+ lua_newtable(L);
+ lua_pushvalue(L, 2);
+ lua_rawseti(L, -2, 1);
+ lua_pushpointer(L, handle);
+ lua_rawseti(L, -2, 2);
+ int ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ /* Save reference to the timer */
+ handle->data = (void *) (intptr_t)ref;
+ lua_pushinteger(L, ref);
+ return 1;
+}
+
+int kr_bindings_event(lua_State *L)
+{
+ static const luaL_Reg lib[] = {
+ { "after", event_after },
+ { "recurrent", event_recurrent },
+ { "cancel", event_cancel },
+ { "socket", event_fdwatch },
+ { "reschedule", event_reschedule },
+ { NULL, NULL }
+ };
+
+ luaL_register(L, "event", lib);
+ return 1;
+}
+
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <lua.h>
+#include <lauxlib.h>
+#include <string.h>
+
+
+const char * lua_table_checkindices(lua_State *L, const char *keys[])
+{
+ /* Iterate over table at the top of the stack.
+ * http://www.lua.org/manual/5.1/manual.html#lua_next */
+ for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+ lua_pop(L, 1); /* we don't need the value */
+ /* We need to copy the key, as _tostring() confuses _next().
+ * https://www.lua.org/manual/5.1/manual.html#lua_tolstring */
+ lua_pushvalue(L, -1);
+ const char *key = lua_tostring(L, -1);
+ if (!key)
+ return "<NON-STRING_INDEX>";
+ for (const char **k = keys; ; ++k) {
+ if (*k == NULL)
+ return key;
+ if (strcmp(*k, key) == 0)
+ break;
+ }
+ }
+ return NULL;
+}
+
+
+/* Each of these just creates the correspondingly named lua table of functions. */
+int kr_bindings_cache (lua_State *L); /* ./cache.c */
+int kr_bindings_event (lua_State *L); /* ./event.c */
+int kr_bindings_modules (lua_State *L); /* ./modules.c */
+int kr_bindings_net (lua_State *L); /* ./net.c */
+int kr_bindings_worker (lua_State *L); /* ./worker.c */
+
+void kr_bindings_register(lua_State *L)
+{
+ kr_bindings_cache(L);
+ kr_bindings_event(L);
+ kr_bindings_modules(L);
+ kr_bindings_net(L);
+ kr_bindings_worker(L);
+}
+
+void lua_error_p(lua_State *L, const char *fmt, ...)
+{
+ /* Add a stack trace and throw the result as a lua error. */
+ luaL_traceback(L, L, "error occurred here (config filename:lineno is at the bottom, if config is involved):", 0);
+ /* Push formatted custom message, prepended with "ERROR: ". */
+ lua_pushliteral(L, "\nERROR: ");
+ {
+ va_list args;
+ va_start(args, fmt);
+ lua_pushvfstring(L, fmt, args);
+ va_end(args);
+ }
+ lua_concat(L, 3);
+ lua_error(L);
+ /* TODO: we might construct a little more friendly trace by using luaL_where().
+ * In particular, in case the error happens in a function that was called
+ * directly from a config file (the most common case), there isn't much need
+ * to format the trace in this heavy way. */
+}
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "engine.h"
+
+#include <lua.h>
+#include <lauxlib.h>
+
+/** Useful to stringify #defines into error strings. */
+#define STR(s) STRINGIFY_TOKEN(s)
+#define STRINGIFY_TOKEN(s) #s
+
+
+/** Check lua table at the top of the stack for allowed keys.
+ * \param keys NULL-terminated array of 0-terminated strings
+ * \return NULL if passed or the offending string (pushed on top of lua stack)
+ * \note Future work: if non-NULL is returned, there's extra stuff on the lua stack.
+ * \note Brute-force complexity: table length * summed length of keys.
+ */
+const char * lua_table_checkindices(lua_State *L, const char *keys[]);
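+
+/*
+ * Usage sketch (the key set is illustrative): with a config table on top of
+ * the lua stack, reject unknown keys before reading the known ones.
+ *
+ *   static const char *allowed[] = { "port", "freebind", NULL };
+ *   const char *bad = lua_table_checkindices(L, allowed);
+ *   if (bad)
+ *       lua_error_p(L, "unexpected table key '%s'", bad);
+ */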
+
+/** If the value at the top of the stack isn't a table, make it a single-element list. */
+static inline void lua_listify(lua_State *L)
+{
+ if (lua_istable(L, -1))
+ return;
+ lua_createtable(L, 1, 0);
+ lua_insert(L, lua_gettop(L) - 1); /* swap the top two stack elements */
+ lua_pushinteger(L, 1);
+ lua_insert(L, lua_gettop(L) - 1); /* swap the top two stack elements */
+ lua_settable(L, -3);
+}
+
+
+/** Throw a formatted lua error.
+ *
+ * The message will get prefixed by "ERROR: " and supplemented by stack trace.
+ * \return never! It calls lua_error().
+ *
+ * Example:
+ ERROR: not a valid pin_sha256: 'a1Z/3ek=', raw length 5 instead of 32
+ stack traceback:
+ [C]: in function 'tls_client'
+ /PathToPREFIX/lib/kdns_modules/policy.lua:175: in function 'TLS_FORWARD'
+ /PathToConfig.lua:46: in main chunk
+ */
+KR_PRINTF(2) KR_NORETURN KR_COLD
+void lua_error_p(lua_State *L, const char *fmt, ...);
+/** @internal Annotate for static checkers. */
+KR_NORETURN int lua_error(lua_State *L);
+
+/** Shortcut for common case. */
+static inline void lua_error_maybe(lua_State *L, int err)
+{
+ if (err) lua_error_p(L, "%s", kr_strerror(err));
+}
+
+static inline int execute_callback(lua_State *L, int argc)
+{
+ int ret = engine_pcall(L, argc);
+ if (ret != 0) {
+ kr_log_error("error: %s\n", lua_tostring(L, -1));
+ }
+ /* Clear the stack; the callback may have left an event and/or anything else on it. */
+ lua_settop(L, 0);
+ return ret;
+}
+
+/** Push a pointer as heavy/full userdata.
+ *
+ * It's useful as a replacement of lua_pushlightuserdata(),
+ * but note that it behaves differently in lua (converts to pointer-to-pointer).
+ */
+static inline void lua_pushpointer(lua_State *L, void *p)
+{
+ void *addr = lua_newuserdata(L, sizeof(void *));
+ assert(addr);
+ memcpy(addr, &p, sizeof(void *));
+}
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "bindings/impl.h"
+
+
+/** List loaded modules */
+static int mod_list(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ lua_newtable(L);
+ for (unsigned i = 0; i < engine->modules.len; ++i) {
+ struct kr_module *module = engine->modules.at[i];
+ lua_pushstring(L, module->name);
+ lua_rawseti(L, -2, i + 1);
+ }
+ return 1;
+}
+
+/** Load module. */
+static int mod_load(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n != 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "expected 'load(string name)'");
+ /* Parse precedence declaration */
+ char *declaration = strdup(lua_tostring(L, 1));
+ if (!declaration)
+ return kr_error(ENOMEM);
+ const char *name = strtok(declaration, " ");
+ const char *precedence = strtok(NULL, " ");
+ const char *ref = strtok(NULL, " ");
+ /* Load engine module */
+ struct engine *engine = engine_luaget(L);
+ int ret = engine_register(engine, name, precedence, ref);
+ free(declaration);
+ if (ret != 0) {
+ if (ret == kr_error(EIDRM)) {
+ lua_error_p(L, "referenced module not found");
+ } else {
+ lua_error_maybe(L, ret);
+ }
+ }
+
+ lua_pushboolean(L, 1);
+ return 1;
+}
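+
+/* Usage sketch from Lua, assuming the engine's space-separated precedence syntax:
+ *   modules.load('hints')            -- plain load
+ *   modules.load('hints > iterate')  -- name, precedence token, reference module
+ */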
+
+/** Unload module. */
+static int mod_unload(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n != 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "expected 'unload(string name)'");
+ /* Unload engine module */
+ struct engine *engine = engine_luaget(L);
+ int ret = engine_unregister(engine, lua_tostring(L, 1));
+ lua_error_maybe(L, ret);
+
+ lua_pushboolean(L, 1);
+ return 1;
+}
+
+int kr_bindings_modules(lua_State *L)
+{
+ static const luaL_Reg lib[] = {
+ { "list", mod_list },
+ { "load", mod_load },
+ { "unload", mod_unload },
+ { NULL, NULL }
+ };
+
+ luaL_register(L, "modules", lib);
+ return 1;
+}
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "bindings/impl.h"
+
+#include "contrib/base64.h"
+#include "network.h"
+#include "tls.h"
+#include "worker.h"
+
+#include <stdlib.h>
+
+/** Table and next index on top of stack -> append entries for given endpoint_array_t. */
+static int net_list_add(const char *key, void *val, void *ext)
+{
+ lua_State *L = (lua_State *)ext;
+ lua_Integer i = lua_tointeger(L, -1);
+ endpoint_array_t *ep_array = val;
+ for (int j = 0; j < ep_array->len; ++j) {
+ struct endpoint *ep = &ep_array->at[j];
+ lua_newtable(L); // connection tuple
+
+ if (ep->flags.kind) {
+ lua_pushstring(L, ep->flags.kind);
+ } else if (ep->flags.tls) {
+ lua_pushliteral(L, "tls");
+ } else {
+ lua_pushliteral(L, "dns");
+ }
+ lua_setfield(L, -2, "kind");
+
+ lua_newtable(L); // "transport" table
+
+ lua_pushboolean(L, ep->flags.freebind);
+ lua_setfield(L, -2, "freebind");
+
+ switch (ep->family) {
+ case AF_INET:
+ lua_pushliteral(L, "inet4");
+ break;
+ case AF_INET6:
+ lua_pushliteral(L, "inet6");
+ break;
+ case AF_UNIX:
+ lua_pushliteral(L, "unix");
+ break;
+ default:
+ lua_pushliteral(L, "invalid");
+ assert(!EINVAL);
+ }
+ lua_setfield(L, -2, "family");
+
+ lua_pushstring(L, key);
+ if (ep->family != AF_UNIX) {
+ lua_setfield(L, -2, "ip");
+ } else {
+ lua_setfield(L, -2, "path");
+ }
+
+ if (ep->family != AF_UNIX) {
+ lua_pushinteger(L, ep->port);
+ lua_setfield(L, -2, "port");
+ }
+
+ if (ep->family == AF_UNIX) {
+ lua_pushliteral(L, "stream");
+ } else if (ep->flags.sock_type == SOCK_STREAM) {
+ lua_pushliteral(L, "tcp");
+ } else if (ep->flags.sock_type == SOCK_DGRAM) {
+ lua_pushliteral(L, "udp");
+ } else {
+ assert(!EINVAL);
+ lua_pushliteral(L, "invalid");
+ }
+ lua_setfield(L, -2, "protocol");
+
+ lua_setfield(L, -2, "transport");
+
+ lua_settable(L, -3);
+ i++;
+ lua_pushinteger(L, i);
+ }
+ return kr_ok();
+}
+
+/** List active endpoints. */
+static int net_list(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ lua_newtable(L);
+ lua_pushinteger(L, 1);
+ map_walk(&engine->net.endpoints, net_list_add, L);
+ lua_pop(L, 1);
+ return 1;
+}
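+
+/* Illustrative shape of one net.list() entry (values depend on configuration):
+ *   { kind = 'dns', transport = { family = 'inet4', ip = '127.0.0.1',
+ *                                 port = 53, protocol = 'udp', freebind = false } }
+ */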
+
+/** Listen on an address list represented by the top of lua stack.
+ * \note kind ownership is not transferred
+ * \return success */
+static bool net_listen_addrs(lua_State *L, int port, bool tls, const char *kind, bool freebind)
+{
+ /* Case: table with 'addr' field; only follow that field directly. */
+ lua_getfield(L, -1, "addr");
+ if (!lua_isnil(L, -1)) {
+ lua_replace(L, -2);
+ } else {
+ lua_pop(L, 1);
+ }
+
+ /* Case: string, representing a single address. */
+ const char *str = lua_tostring(L, -1);
+ if (str != NULL) {
+ struct engine *engine = engine_luaget(L);
+ int ret = 0;
+ endpoint_flags_t flags = { .tls = tls, .freebind = freebind };
+ if (!kind && !flags.tls) { /* normal UDP */
+ flags.sock_type = SOCK_DGRAM;
+ ret = network_listen(&engine->net, str, port, flags);
+ }
+ if (!kind && ret == 0) { /* common for normal TCP and TLS */
+ flags.sock_type = SOCK_STREAM;
+ ret = network_listen(&engine->net, str, port, flags);
+ }
+ if (kind) {
+ flags.kind = strdup(kind);
+ flags.sock_type = SOCK_STREAM; /* TODO: allow to override this? */
+ ret = network_listen(&engine->net, str, port, flags);
+ }
+ if (ret != 0) {
+ const char *stype = flags.sock_type == SOCK_DGRAM ? "UDP" : "TCP";
+ kr_log_error("[system] bind to '%s@%d' (%s): %s\n",
+ str, port, stype, kr_strerror(ret));
+ }
+ return ret == 0;
+ }
+
+ /* Last case: table where all entries are added recursively. */
+ if (!lua_istable(L, -1))
+ lua_error_p(L, "bad type for address");
+ lua_pushnil(L);
+ while (lua_next(L, -2)) {
+ if (!net_listen_addrs(L, port, tls, kind, freebind))
+ return false;
+ lua_pop(L, 1);
+ }
+ return true;
+}
+
+static bool table_get_flag(lua_State *L, int index, const char *key, bool def)
+{
+ bool result = def;
+ lua_getfield(L, index, key);
+ if (lua_isboolean(L, -1)) {
+ result = lua_toboolean(L, -1);
+ }
+ lua_pop(L, 1);
+ return result;
+}
+
+/** Listen on endpoint. */
+static int net_listen(lua_State *L)
+{
+ /* Check parameters */
+ int n = lua_gettop(L);
+ if (n < 1 || n > 3) {
+ lua_error_p(L, "expected one to three arguments; usage:\n"
+			"net.listen(addresses, [port = " STR(KR_DNS_PORT)
+ ", flags = {tls = (port == " STR(KR_DNS_TLS_PORT) ")}])\n");
+ }
+
+ int port = KR_DNS_PORT;
+ if (n > 1) {
+ if (lua_isnumber(L, 2)) {
+ port = lua_tointeger(L, 2);
+ } else
+ if (!lua_isnil(L, 2)) {
+ lua_error_p(L, "wrong type of second parameter (port number)");
+ }
+ }
+
+ bool tls = (port == KR_DNS_TLS_PORT);
+ bool freebind = false;
+ const char *kind = NULL;
+ if (n > 2 && !lua_isnil(L, 3)) {
+ if (!lua_istable(L, 3))
+ lua_error_p(L, "wrong type of third parameter (table expected)");
+ tls = table_get_flag(L, 3, "tls", tls);
+		freebind = table_get_flag(L, 3, "freebind", freebind);
+
+ lua_getfield(L, 3, "kind");
+ const char *k = lua_tostring(L, -1);
+ if (k && strcasecmp(k, "dns") == 0) {
+ tls = false;
+ } else
+ if (k && strcasecmp(k, "tls") == 0) {
+ tls = true;
+ } else
+ if (k) {
+ kind = k;
+ }
+ }
+
+ /* Memory management of `kind` string is difficult due to longjmp etc.
+ * Pop will unreference the lua value, so we store it on C stack instead (!) */
+ const int kind_alen = kind ? strlen(kind) + 1 : 1 /* 0 length isn't C standard */;
+ char kind_buf[kind_alen];
+ if (kind) {
+ memcpy(kind_buf, kind, kind_alen);
+ kind = kind_buf;
+ }
+
+ /* Now focus on the first argument. */
+ lua_settop(L, 1);
+ if (!net_listen_addrs(L, port, tls, kind, freebind))
+ lua_error_p(L, "net.listen() failed to bind");
+ lua_pushboolean(L, true);
+ return 1;
+}
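+
+/* Usage sketch from Lua:
+ *   net.listen('127.0.0.1')                  -- UDP + TCP on port 53
+ *   net.listen('::1', 853, { tls = true })   -- DNS-over-TLS
+ *   net.listen({'127.0.0.1', '::1'}, 8453, { kind = 'doh' })  -- assumes such a kind was registered
+ */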
+
+/** Close endpoint. */
+static int net_close(lua_State *L)
+{
+ /* Check parameters */
+ const int n = lua_gettop(L);
+ bool ok = (n == 1 || n == 2) && lua_isstring(L, 1);
+ const char *addr = lua_tostring(L, 1);
+ int port;
+ if (ok && (n < 2 || lua_isnil(L, 2))) {
+ port = -1;
+ } else if (ok) {
+ ok = lua_isnumber(L, 2);
+ port = lua_tointeger(L, 2);
+ ok = ok && port >= 0 && port <= 65535;
+ }
+ if (!ok)
+ lua_error_p(L, "expected 'close(string addr, [number port])'");
+
+ struct network *net = &engine_luaget(L)->net;
+ int ret = network_close(net, addr, port);
+ lua_pushboolean(L, ret == 0);
+ return 1;
+}
+
+/** List available interfaces. */
+static int net_interfaces(lua_State *L)
+{
+ /* Retrieve interface list */
+ int count = 0;
+ char buf[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */
+ uv_interface_address_t *info = NULL;
+ uv_interface_addresses(&info, &count);
+ lua_newtable(L);
+ for (int i = 0; i < count; ++i) {
+ uv_interface_address_t iface = info[i];
+ lua_getfield(L, -1, iface.name);
+ if (lua_isnil(L, -1)) {
+ lua_pop(L, 1);
+ lua_newtable(L);
+ }
+
+ /* Address */
+ lua_getfield(L, -1, "addr");
+ if (lua_isnil(L, -1)) {
+ lua_pop(L, 1);
+ lua_newtable(L);
+ }
+ if (iface.address.address4.sin_family == AF_INET) {
+ uv_ip4_name(&iface.address.address4, buf, sizeof(buf));
+ } else if (iface.address.address4.sin_family == AF_INET6) {
+ uv_ip6_name(&iface.address.address6, buf, sizeof(buf));
+ } else {
+ buf[0] = '\0';
+ }
+ lua_pushstring(L, buf);
+ lua_rawseti(L, -2, lua_objlen(L, -2) + 1);
+ lua_setfield(L, -2, "addr");
+
+ /* Hardware address. */
+ char *p = buf;
+ for (int k = 0; k < sizeof(iface.phys_addr); ++k) {
+ sprintf(p, "%.2x:", (uint8_t)iface.phys_addr[k]);
+ p += 3;
+ }
+ p[-1] = '\0';
+ lua_pushstring(L, buf);
+ lua_setfield(L, -2, "mac");
+
+ /* Push table */
+ lua_setfield(L, -2, iface.name);
+ }
+ uv_free_interface_addresses(info, count);
+
+ return 1;
+}
+
+/** Set UDP maximum payload size. */
+static int net_bufsize(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ knot_rrset_t *opt_rr = engine->resolver.opt_rr;
+ if (!lua_isnumber(L, 1)) {
+ lua_pushinteger(L, knot_edns_get_payload(opt_rr));
+ return 1;
+ }
+ int bufsize = lua_tointeger(L, 1);
+ if (bufsize < 512 || bufsize > UINT16_MAX)
+ lua_error_p(L, "bufsize must be within <512, " STR(UINT16_MAX) ">");
+ knot_edns_set_payload(opt_rr, (uint16_t) bufsize);
+ return 0;
+}
+
+/** Set TCP pipelining size. */
+static int net_pipeline(lua_State *L)
+{
+ struct worker_ctx *worker = the_worker;
+ if (!worker) {
+ return 0;
+ }
+ if (!lua_isnumber(L, 1)) {
+ lua_pushinteger(L, worker->tcp_pipeline_max);
+ return 1;
+ }
+ int len = lua_tointeger(L, 1);
+ if (len < 0 || len > UINT16_MAX)
+ lua_error_p(L, "tcp_pipeline must be within <0, " STR(UINT16_MAX) ">");
+ worker->tcp_pipeline_max = len;
+ lua_pushinteger(L, len);
+ return 1;
+}
+
+static int net_tls(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ if (!engine) {
+ return 0;
+ }
+ struct network *net = &engine->net;
+ if (!net) {
+ return 0;
+ }
+
+ /* Only return current credentials. */
+ if (lua_gettop(L) == 0) {
+ /* No credentials configured yet. */
+ if (!net->tls_credentials) {
+ return 0;
+ }
+ lua_newtable(L);
+ lua_pushstring(L, net->tls_credentials->tls_cert);
+ lua_setfield(L, -2, "cert_file");
+ lua_pushstring(L, net->tls_credentials->tls_key);
+ lua_setfield(L, -2, "key_file");
+ return 1;
+ }
+
+ if ((lua_gettop(L) != 2) || !lua_isstring(L, 1) || !lua_isstring(L, 2))
+ lua_error_p(L, "net.tls takes two parameters: (\"cert_file\", \"key_file\")");
+
+ int r = tls_certificate_set(net, lua_tostring(L, 1), lua_tostring(L, 2));
+ lua_error_maybe(L, r);
+
+ lua_pushboolean(L, true);
+ return 1;
+}
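+
+/* Usage sketch (paths are illustrative):
+ *   net.tls('/etc/knot-resolver/server-cert.pem', '/etc/knot-resolver/server-key.pem')
+ *   net.tls()  -- returns { cert_file = ..., key_file = ... } once configured
+ */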
+
+/** Return a lua table with TLS authentication parameters.
+ * The format is the same as passed to policy.TLS_FORWARD();
+ * more precisely, it's in a compatible canonical form. */
+static int tls_params2lua(lua_State *L, trie_t *params)
+{
+ lua_newtable(L);
+ if (!params) /* Allowed special case. */
+ return 1;
+ trie_it_t *it;
+ size_t list_index = 0;
+ for (it = trie_it_begin(params); !trie_it_finished(it); trie_it_next(it)) {
+ /* Prepare table for the current address
+ * and its index in the returned list. */
+ lua_pushinteger(L, ++list_index);
+ lua_createtable(L, 0, 2);
+
+ /* Get the "addr#port" string... */
+ size_t ia_len;
+ const char *key = trie_it_key(it, &ia_len);
+ int af = AF_UNSPEC;
+ if (ia_len == 2 + sizeof(struct in_addr)) {
+ af = AF_INET;
+ } else if (ia_len == 2 + sizeof(struct in6_addr)) {
+ af = AF_INET6;
+ }
+ if (!key || af == AF_UNSPEC) {
+ assert(false);
+ lua_error_p(L, "internal error: bad IP address");
+ }
+ uint16_t port;
+ memcpy(&port, key, sizeof(port));
+ port = ntohs(port);
+ const char *ia = key + sizeof(port);
+ char str[INET6_ADDRSTRLEN + 1 + 5 + 1];
+ size_t len = sizeof(str);
+ if (kr_ntop_str(af, ia, port, str, &len) != kr_ok()) {
+ assert(false);
+ lua_error_p(L, "internal error: bad IP address conversion");
+ }
+ /* ...and push it as [1]. */
+ lua_pushinteger(L, 1);
+ lua_pushlstring(L, str, len - 1 /* len includes '\0' */);
+ lua_settable(L, -3);
+
+ const tls_client_param_t *e = *trie_it_val(it);
+ if (!e)
+ lua_error_p(L, "internal problem - NULL entry for %s", str);
+
+ /* .hostname = */
+ if (e->hostname) {
+ lua_pushstring(L, e->hostname);
+ lua_setfield(L, -2, "hostname");
+ }
+
+ /* .ca_files = */
+ if (e->ca_files.len) {
+ lua_createtable(L, e->ca_files.len, 0);
+ for (size_t i = 0; i < e->ca_files.len; ++i) {
+ lua_pushinteger(L, i + 1);
+ lua_pushstring(L, e->ca_files.at[i]);
+ lua_settable(L, -3);
+ }
+ lua_setfield(L, -2, "ca_files");
+ }
+
+ /* .pin_sha256 = ... ; keep sane indentation via goto. */
+ if (!e->pins.len) goto no_pins;
+ lua_createtable(L, e->pins.len, 0);
+ for (size_t i = 0; i < e->pins.len; ++i) {
+ uint8_t pin_base64[TLS_SHA256_BASE64_BUFLEN];
+ int err = base64_encode(e->pins.at[i], TLS_SHA256_RAW_LEN,
+ pin_base64, sizeof(pin_base64));
+ if (err < 0) {
+ assert(false);
+ lua_error_p(L,
+ "internal problem when converting pin_sha256: %s",
+ kr_strerror(err));
+ }
+ lua_pushinteger(L, i + 1);
+ lua_pushlstring(L, (const char *)pin_base64, err);
+ /* pin_base64 isn't 0-terminated ^^^ */
+ lua_settable(L, -3);
+ }
+ lua_setfield(L, -2, "pin_sha256");
+
+ no_pins:/* .insecure = */
+ if (e->insecure) {
+ lua_pushboolean(L, true);
+ lua_setfield(L, -2, "insecure");
+ }
+ /* Now the whole table is pushed atop the returned list. */
+ lua_settable(L, -3);
+ }
+ trie_it_free(it);
+ return 1;
+}
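+
+/* Illustrative return value (addresses, names and paths made up for the example):
+ *   { { '192.0.2.1#853', hostname = 'dns.example', ca_files = { '/path/to/ca.pem' } },
+ *     { '2001:db8::1#853', pin_sha256 = { '...base64 digest...' } } }
+ */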
+
+static inline int cmp_sha256(const void *p1, const void *p2)
+{
+ return memcmp(*(char * const *)p1, *(char * const *)p2, TLS_SHA256_RAW_LEN);
+}
+static int net_tls_client(lua_State *L)
+{
+ /* TODO idea: allow starting the lua table with *multiple* IP targets,
+ * meaning the authentication config should be applied to each.
+ */
+ struct network *net = &engine_luaget(L)->net;
+ if (lua_gettop(L) == 0)
+ return tls_params2lua(L, net->tls_client_params);
+ /* Various basic sanity-checking. */
+ if (lua_gettop(L) != 1 || !lua_istable(L, 1))
+ lua_error_maybe(L, EINVAL);
+ /* check that only allowed keys are present */
+ {
+ const char *bad_key = lua_table_checkindices(L, (const char *[])
+ { "1", "hostname", "ca_file", "pin_sha256", "insecure", NULL });
+ if (bad_key)
+ lua_error_p(L, "found unexpected key '%s'", bad_key);
+ }
+
+ /**** Phase 1: get the parameter into a C struct, incl. parse of CA files,
+ * regardless of the address-pair having an entry already. */
+
+ tls_client_param_t *newcfg = tls_client_param_new();
+ if (!newcfg)
+ lua_error_p(L, "out of memory or something like that :-/");
+ /* Shortcut for cleanup actions needed from now on. */
+ #define ERROR(...) do { \
+ free(newcfg); \
+ lua_error_p(L, __VA_ARGS__); \
+ } while (false)
+
+ /* .hostname - always accepted. */
+ lua_getfield(L, 1, "hostname");
+ if (!lua_isnil(L, -1)) {
+ const char *hn_str = lua_tostring(L, -1);
+ /* Convert to lower-case dname and back, for checking etc. */
+ knot_dname_t dname[KNOT_DNAME_MAXLEN];
+ if (!hn_str || !knot_dname_from_str(dname, hn_str, sizeof(dname)))
+ ERROR("invalid hostname");
+ knot_dname_to_lower(dname);
+ char *h = knot_dname_to_str_alloc(dname);
+ if (!h)
+ ERROR("%s", kr_strerror(ENOMEM));
+ /* Strip the final dot produced by knot_dname_*() */
+ h[strlen(h) - 1] = '\0';
+ newcfg->hostname = h;
+ }
+ lua_pop(L, 1);
+
+ /* .ca_file - it can be a list of paths, contrary to the name. */
+ bool has_ca_file = false;
+ lua_getfield(L, 1, "ca_file");
+ if (!lua_isnil(L, -1)) {
+ if (!newcfg->hostname)
+ ERROR("missing hostname but specifying ca_file");
+ lua_listify(L);
+ array_init(newcfg->ca_files); /*< placate apparently confused scan-build */
+ if (array_reserve(newcfg->ca_files, lua_objlen(L, -1)) != 0) /*< optim. */
+ ERROR("%s", kr_strerror(ENOMEM));
+ /* Iterate over table at the top of the stack.
+ * http://www.lua.org/manual/5.1/manual.html#lua_next */
+ for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+ has_ca_file = true; /* deferred here so that {} -> false */
+ const char *ca_file = lua_tostring(L, -1);
+ if (!ca_file)
+ ERROR("ca_file contains a non-string");
+ /* Let gnutls process it immediately, so garbage gets detected. */
+ int ret = gnutls_certificate_set_x509_trust_file(
+ newcfg->credentials, ca_file, GNUTLS_X509_FMT_PEM);
+ if (ret < 0) {
+ ERROR("failed to import certificate file '%s': %s - %s\n",
+ ca_file, gnutls_strerror_name(ret),
+ gnutls_strerror(ret));
+ } else {
+ kr_log_verbose(
+ "[tls_client] imported %d certs from file '%s'\n",
+ ret, ca_file);
+ }
+
+ ca_file = strdup(ca_file);
+ if (!ca_file || array_push(newcfg->ca_files, ca_file) < 0)
+ ERROR("%s", kr_strerror(ENOMEM));
+ }
+ /* Sort the strings for easier comparison later. */
+ if (newcfg->ca_files.len) {
+ qsort(&newcfg->ca_files.at[0], newcfg->ca_files.len,
+ sizeof(newcfg->ca_files.at[0]), strcmp_p);
+ }
+ }
+ lua_pop(L, 1);
+
+ /* .pin_sha256 */
+ lua_getfield(L, 1, "pin_sha256");
+ if (!lua_isnil(L, -1)) {
+ if (has_ca_file)
+ ERROR("mixing pin_sha256 with ca_file is not supported");
+ lua_listify(L);
+ array_init(newcfg->pins); /*< placate apparently confused scan-build */
+ if (array_reserve(newcfg->pins, lua_objlen(L, -1)) != 0) /*< optim. */
+ ERROR("%s", kr_strerror(ENOMEM));
+ /* Iterate over table at the top of the stack. */
+ for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+ const char *pin = lua_tostring(L, -1);
+ if (!pin)
+ ERROR("pin_sha256 is not a string");
+			uint8_t *pin_raw = malloc(TLS_SHA256_RAW_LEN + 8); /* sized to match the decode call below */
+ /* Push the string early to simplify error processing. */
+ if (!pin_raw || array_push(newcfg->pins, pin_raw) < 0) {
+ assert(false);
+ free(pin_raw);
+ ERROR("%s", kr_strerror(ENOMEM));
+ }
+ int ret = base64_decode((const uint8_t *)pin, strlen(pin),
+ pin_raw, TLS_SHA256_RAW_LEN + 8);
+ if (ret < 0) {
+ ERROR("not a valid pin_sha256: '%s' (length %d), %s\n",
+ pin, (int)strlen(pin), knot_strerror(ret));
+ } else if (ret != TLS_SHA256_RAW_LEN) {
+ ERROR("not a valid pin_sha256: '%s', "
+ "raw length %d instead of "
+ STR(TLS_SHA256_RAW_LEN)"\n",
+ pin, ret);
+ }
+ }
+ /* Sort the raw strings for easier comparison later. */
+ if (newcfg->pins.len) {
+ qsort(&newcfg->pins.at[0], newcfg->pins.len,
+ sizeof(newcfg->pins.at[0]), cmp_sha256);
+ }
+ }
+ lua_pop(L, 1);
+
+ /* .insecure */
+ lua_getfield(L, 1, "insecure");
+ if (lua_isnil(L, -1)) {
+ if (!newcfg->hostname && !newcfg->pins.len)
+ ERROR("no way to authenticate and not set as insecure");
+ } else if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) {
+ newcfg->insecure = true;
+ if (has_ca_file || newcfg->pins.len)
+ ERROR("set as insecure but provided authentication config");
+ } else {
+ ERROR("incorrect value in the 'insecure' field");
+ }
+ lua_pop(L, 1);
+
+ /* Init CAs from system trust store, if needed. */
+ if (!newcfg->insecure && !newcfg->pins.len && !has_ca_file) {
+ int ret = gnutls_certificate_set_x509_system_trust(newcfg->credentials);
+ if (ret <= 0) {
+ ERROR("failed to use system CA certificate store: %s",
+ ret ? gnutls_strerror(ret) : kr_strerror(ENOENT));
+ } else {
+ kr_log_verbose(
+ "[tls_client] imported %d certs from system store\n",
+ ret);
+ }
+ }
+ #undef ERROR
+
+ /**** Phase 2: deal with the C authentication "table". */
+ /* Parse address and port. */
+ lua_pushinteger(L, 1);
+ lua_gettable(L, 1);
+ const char *addr_str = lua_tostring(L, -1);
+ if (!addr_str)
+ lua_error_p(L, "address is not a string");
+ char buf[INET6_ADDRSTRLEN + 1];
+ uint16_t port = 853;
+ const struct sockaddr *addr = NULL;
+ if (kr_straddr_split(addr_str, buf, &port) == kr_ok())
+ addr = kr_straddr_socket(buf, port, NULL);
+ /* Add newcfg into the C map, saving the original into oldcfg. */
+ if (!addr)
+ lua_error_p(L, "address '%s' could not be converted", addr_str);
+ tls_client_param_t **oldcfgp = tls_client_param_getptr(
+ &net->tls_client_params, addr, true);
+ free_const(addr);
+ if (!oldcfgp)
+ lua_error_p(L, "internal error when extending tls_client_params map");
+ tls_client_param_t *oldcfg = *oldcfgp;
+ *oldcfgp = newcfg; /* replace old config in trie with the new one */
+ /* If there was no original entry, it's easy! */
+ if (!oldcfg)
+ return 0;
+
+ /* Check for equality (newcfg vs. oldcfg), and print a warning if not equal.*/
+ const bool ok_h = (!newcfg->hostname && !oldcfg->hostname)
+ || (newcfg->hostname && oldcfg->hostname && strcmp(newcfg->hostname, oldcfg->hostname) == 0);
+ bool ok_ca = newcfg->ca_files.len == oldcfg->ca_files.len;
+ for (int i = 0; ok_ca && i < newcfg->ca_files.len; ++i)
+ ok_ca = strcmp(newcfg->ca_files.at[i], oldcfg->ca_files.at[i]) == 0;
+ bool ok_pins = newcfg->pins.len == oldcfg->pins.len;
+ for (int i = 0; ok_pins && i < newcfg->pins.len; ++i)
+		ok_pins = memcmp(newcfg->pins.at[i], oldcfg->pins.at[i], TLS_SHA256_RAW_LEN) == 0;
+ const bool ok_insecure = newcfg->insecure == oldcfg->insecure;
+ if (!(ok_h && ok_ca && ok_pins && ok_insecure)) {
+ kr_log_info("[tls_client] "
+ "warning: re-defining TLS authentication parameters for %s\n",
+ addr_str);
+ }
+ tls_client_param_unref(oldcfg);
+ return 0;
+}
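+
+/* Usage sketch, mirroring the keys checked above (values are illustrative):
+ *   net.tls_client({ '192.0.2.1', hostname = 'dns.example', ca_file = '/path/to/ca.pem' })
+ *   net.tls_client({ '192.0.2.1#853', insecure = true })
+ */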
+
+int net_tls_client_clear(lua_State *L)
+{
+ /* One parameter: address -> convert it to a struct sockaddr. */
+ if (lua_gettop(L) != 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "net.tls_client_clear() requires one parameter (\"address\")");
+ const char *addr_str = lua_tostring(L, 1);
+ char buf[INET6_ADDRSTRLEN + 1];
+ uint16_t port = 853;
+ const struct sockaddr *addr = NULL;
+ if (kr_straddr_split(addr_str, buf, &port) == kr_ok())
+ addr = kr_straddr_socket(buf, port, NULL);
+ if (!addr)
+ lua_error_p(L, "invalid IP address");
+ /* Do the actual removal. */
+ struct network *net = &engine_luaget(L)->net;
+ int r = tls_client_param_remove(net->tls_client_params, addr);
+ free_const(addr);
+ lua_error_maybe(L, r);
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+static int net_tls_padding(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+
+ /* Only return current padding. */
+ if (lua_gettop(L) == 0) {
+ if (engine->resolver.tls_padding < 0) {
+ lua_pushboolean(L, true);
+ return 1;
+ } else if (engine->resolver.tls_padding == 0) {
+ lua_pushboolean(L, false);
+ return 1;
+ }
+ lua_pushinteger(L, engine->resolver.tls_padding);
+ return 1;
+ }
+
+ const char *errstr = "net.tls_padding parameter has to be true, false,"
+ " or a number between <0, " STR(MAX_TLS_PADDING) ">";
+ if (lua_gettop(L) != 1)
+ lua_error_p(L, "%s", errstr);
+ if (lua_isboolean(L, 1)) {
+ bool x = lua_toboolean(L, 1);
+ if (x) {
+ engine->resolver.tls_padding = -1;
+ } else {
+ engine->resolver.tls_padding = 0;
+ }
+ } else if (lua_isnumber(L, 1)) {
+ int padding = lua_tointeger(L, 1);
+ if ((padding < 0) || (padding > MAX_TLS_PADDING))
+ lua_error_p(L, "%s", errstr);
+ engine->resolver.tls_padding = padding;
+ } else {
+ lua_error_p(L, "%s", errstr);
+ }
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+/** Shorter salt can't contain much entropy. */
+#define net_tls_sticket_MIN_SECRET_LEN 32
+
+static int net_tls_sticket_secret_string(lua_State *L)
+{
+ struct network *net = &engine_luaget(L)->net;
+
+ size_t secret_len;
+ const char *secret;
+
+ if (lua_gettop(L) == 0) {
+ /* Zero-length secret, implying random key. */
+ secret_len = 0;
+ secret = NULL;
+ } else {
+ if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) {
+ lua_error_p(L,
+ "net.tls_sticket_secret takes one parameter: (\"secret string\")");
+ }
+ secret = lua_tolstring(L, 1, &secret_len);
+ if (secret_len < net_tls_sticket_MIN_SECRET_LEN || !secret) {
+ lua_error_p(L, "net.tls_sticket_secret - the secret is shorter than "
+ STR(net_tls_sticket_MIN_SECRET_LEN) " bytes");
+ }
+ }
+
+ tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx);
+ net->tls_session_ticket_ctx =
+ tls_session_ticket_ctx_create(net->loop, secret, secret_len);
+ if (net->tls_session_ticket_ctx == NULL) {
+ lua_error_p(L,
+ "net.tls_sticket_secret_string - can't create session ticket context");
+ }
+
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+static int net_tls_sticket_secret_file(lua_State *L)
+{
+ if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) {
+ lua_error_p(L,
+ "net.tls_sticket_secret_file takes one parameter: (\"file name\")");
+ }
+
+ const char *file_name = lua_tostring(L, 1);
+ if (strlen(file_name) == 0)
+ lua_error_p(L, "net.tls_sticket_secret_file - empty file name");
+
+ FILE *fp = fopen(file_name, "r");
+ if (fp == NULL) {
+ lua_error_p(L, "net.tls_sticket_secret_file - can't open file '%s': %s",
+ file_name, strerror(errno));
+ }
+
+ char secret_buf[TLS_SESSION_TICKET_SECRET_MAX_LEN];
+ const size_t secret_len = fread(secret_buf, 1, sizeof(secret_buf), fp);
+ int err = ferror(fp);
+ if (err) {
+ lua_error_p(L,
+ "net.tls_sticket_secret_file - error reading from file '%s': %s",
+ file_name, strerror(err));
+ }
+ if (secret_len < net_tls_sticket_MIN_SECRET_LEN) {
+ lua_error_p(L,
+ "net.tls_sticket_secret_file - file '%s' is shorter than "
+ STR(net_tls_sticket_MIN_SECRET_LEN) " bytes",
+ file_name);
+ }
+ fclose(fp);
+
+ struct network *net = &engine_luaget(L)->net;
+
+ tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx);
+ net->tls_session_ticket_ctx =
+ tls_session_ticket_ctx_create(net->loop, secret_buf, secret_len);
+ if (net->tls_session_ticket_ctx == NULL) {
+ lua_error_p(L,
+ "net.tls_sticket_secret_file - can't create session ticket context");
+ }
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+static int net_outgoing(lua_State *L, int family)
+{
+ union inaddr *addr;
+ if (family == AF_INET)
+ addr = (union inaddr*)&the_worker->out_addr4;
+ else
+ addr = (union inaddr*)&the_worker->out_addr6;
+
+ if (lua_gettop(L) == 0) { /* Return the current value. */
+ if (addr->ip.sa_family == AF_UNSPEC) {
+ lua_pushnil(L);
+ return 1;
+ }
+ if (addr->ip.sa_family != family) {
+ assert(false);
+ lua_error_p(L, "bad address family");
+ }
+ char addr_buf[INET6_ADDRSTRLEN];
+ int err;
+ if (family == AF_INET)
+ err = uv_ip4_name(&addr->ip4, addr_buf, sizeof(addr_buf));
+ else
+ err = uv_ip6_name(&addr->ip6, addr_buf, sizeof(addr_buf));
+ lua_error_maybe(L, err);
+ lua_pushstring(L, addr_buf);
+ return 1;
+ }
+
+ if ((lua_gettop(L) != 1) || (!lua_isstring(L, 1) && !lua_isnil(L, 1)))
+ lua_error_p(L, "net.outgoing_vX takes one address string parameter or nil");
+
+ if (lua_isnil(L, 1)) {
+ addr->ip.sa_family = AF_UNSPEC;
+ return 1;
+ }
+
+ const char *addr_str = lua_tostring(L, 1);
+ int err;
+ if (family == AF_INET)
+ err = uv_ip4_addr(addr_str, 0, &addr->ip4);
+ else
+ err = uv_ip6_addr(addr_str, 0, &addr->ip6);
+ if (err)
+ lua_error_p(L, "net.outgoing_vX: failed to parse the address");
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+static int net_outgoing_v4(lua_State *L) { return net_outgoing(L, AF_INET); }
+static int net_outgoing_v6(lua_State *L) { return net_outgoing(L, AF_INET6); }
+
+static int net_update_timeout(lua_State *L, uint64_t *timeout, const char *name)
+{
+ /* Only return current idle timeout. */
+ if (lua_gettop(L) == 0) {
+ lua_pushinteger(L, *timeout);
+ return 1;
+ }
+
+ if ((lua_gettop(L) != 1))
+ lua_error_p(L, "%s takes one parameter: (\"idle timeout\")", name);
+
+ if (lua_isnumber(L, 1)) {
+ int idle_timeout = lua_tointeger(L, 1);
+ if (idle_timeout <= 0)
+ lua_error_p(L, "%s parameter has to be positive number", name);
+ *timeout = idle_timeout;
+ } else {
+ lua_error_p(L, "%s parameter has to be positive number", name);
+ }
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+static int net_tcp_in_idle(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ struct network *net = &engine->net;
+
+ return net_update_timeout(L, &net->tcp.in_idle_timeout, "net.tcp_in_idle");
+}
+
+static int net_tls_handshake_timeout(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ struct network *net = &engine->net;
+
+ return net_update_timeout(L, &net->tcp.tls_handshake_timeout, "net.tls_handshake_timeout");
+}
+
+static int net_bpf_set(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+
+ if (lua_gettop(L) != 1 || !lua_isnumber(L, 1)) {
+ lua_error_p(L, "net.bpf_set(fd) takes one parameter:"
+ " the open file descriptor of a loaded BPF program");
+ }
+
+#if __linux__
+
+ struct network *net = &engine->net;
+ int progfd = lua_tointeger(L, 1);
+ if (progfd == 0) {
+ /* conversion error despite that fact
+ * that lua_isnumber(L, 1) has returned true.
+ * Real or stdin? */
+ lua_error_p(L, "failed to convert parameter");
+ }
+ lua_pop(L, 1);
+
+ if (network_set_bpf(net, progfd) == 0) {
+ lua_error_p(L, "failed to attach BPF program to some networks: %s",
+ kr_strerror(errno));
+ }
+
+ lua_pushboolean(L, 1);
+ return 1;
+
+#endif
+ lua_error_p(L, "BPF is not supported on this operating system");
+}
+
+static int net_bpf_clear(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+
+ if (lua_gettop(L) != 0)
+ lua_error_p(L, "net.bpf_clear() does not take any parameters");
+
+#if __linux__
+
+ struct network *net = &engine->net;
+ network_clear_bpf(net);
+
+ lua_pushboolean(L, 1);
+ return 1;
+
+#endif
+ lua_error_p(L, "BPF is not supported on this operating system");
+}
+
+static int net_register_endpoint_kind(lua_State *L)
+{
+ const int param_count = lua_gettop(L);
+ if (param_count != 1 && param_count != 2)
+ lua_error_p(L, "expected one or two parameters");
+ if (!lua_isstring(L, 1)) {
+ lua_error_p(L, "incorrect kind '%s'", lua_tostring(L, 1));
+ }
+ size_t kind_len;
+ const char *kind = lua_tolstring(L, 1, &kind_len);
+ struct network *net = &engine_luaget(L)->net;
+
+ /* Unregistering */
+ if (param_count == 1) {
+ void *val;
+ if (trie_del(net->endpoint_kinds, kind, kind_len, &val) == KNOT_EOK) {
+ const int fun_id = (char *)val - (char *)NULL;
+ luaL_unref(L, LUA_REGISTRYINDEX, fun_id);
+ return 0;
+ }
+ lua_error_p(L, "attempt to unregister unknown kind '%s'\n", kind);
+ } /* else */
+
+ /* Registering */
+ assert(param_count == 2);
+ if (!lua_isfunction(L, 2)) {
+ lua_error_p(L, "second parameter: expected function but got %s\n",
+ lua_typename(L, lua_type(L, 2)));
+ }
+ const int fun_id = luaL_ref(L, LUA_REGISTRYINDEX);
+ /* ^^ The function is on top of the stack, incidentally. */
+ void **pp = trie_get_ins(net->endpoint_kinds, kind, kind_len);
+ if (!pp) lua_error_maybe(L, kr_error(ENOMEM));
+ if (*pp != NULL || !strcasecmp(kind, "dns") || !strcasecmp(kind, "tls"))
+ lua_error_p(L, "attempt to register known kind '%s'\n", kind);
+ *pp = (char *)NULL + fun_id;
+	/* We don't attempt to engage corresponding endpoints now.
+	 * That's the job for network_engage_endpoints() later. */
+ return 0;
+}
+
+int kr_bindings_net(lua_State *L)
+{
+ static const luaL_Reg lib[] = {
+ { "list", net_list },
+ { "listen", net_listen },
+ { "close", net_close },
+ { "interfaces", net_interfaces },
+ { "bufsize", net_bufsize },
+ { "tcp_pipeline", net_pipeline },
+ { "tls", net_tls },
+ { "tls_server", net_tls },
+ { "tls_client", net_tls_client },
+ { "tls_client_clear", net_tls_client_clear },
+ { "tls_padding", net_tls_padding },
+ { "tls_sticket_secret", net_tls_sticket_secret_string },
+ { "tls_sticket_secret_file", net_tls_sticket_secret_file },
+ { "outgoing_v4", net_outgoing_v4 },
+ { "outgoing_v6", net_outgoing_v6 },
+ { "tcp_in_idle", net_tcp_in_idle },
+ { "tls_handshake_timeout", net_tls_handshake_timeout },
+ { "bpf_set", net_bpf_set },
+ { "bpf_clear", net_bpf_clear },
+ { "register_endpoint_kind", net_register_endpoint_kind },
+ { NULL, NULL }
+ };
+ luaL_register(L, "net", lib);
+ return 1;
+}
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "bindings/impl.h"
+
+#include "contrib/base64.h"
+#include "watcher.h"
+
+#include <stdlib.h>
+
+static bool table_get_flag(lua_State *L, int index, const char *key, bool def)
+{
+ bool result = def;
+ lua_getfield(L, index, key);
+ if (lua_isboolean(L, -1)) {
+ result = lua_toboolean(L, -1);
+ }
+ lua_pop(L, 1);
+ return result;
+}
+
+static int watcher_server(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ if (!engine) {
+ return 0;
+ }
+ struct watcher_context *watcher = &engine->watcher;
+ if (!watcher) {
+ return 0;
+ }
+
+ /* Only return current credentials. */
+ if (lua_gettop(L) == 0) {
+ lua_newtable(L);
+		lua_pushstring(L, watcher->config.auto_start);
+		lua_setfield(L, -2, "auto_start");
+		lua_pushstring(L, watcher->config.auto_cache_gc);
+		lua_setfield(L, -2, "auto_cache_gc");
+		lua_pushnumber(L, watcher->config.kresd_instances);
+ lua_setfield(L, -2, "kresd_instances");
+ return 1;
+ }
+
+ lua_pushboolean(L, true);
+ return 1;
+}
+
+int kr_bindings_watcher(lua_State *L)
+{
+ static const luaL_Reg lib[] = {
+ { "server", watcher_server },
+
+ { NULL, NULL }
+ };
+ luaL_register(L, "watcher", lib);
+ return 1;
+}
+
--- /dev/null
+/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "bindings/impl.h"
+
+#include "worker.h"
+
+static inline double getseconds(uv_timeval_t *tv)
+{
+ return (double)tv->tv_sec + 0.000001*((double)tv->tv_usec);
+}
+
+/** Return worker statistics. */
+static int wrk_stats(lua_State *L)
+{
+ struct worker_ctx *worker = the_worker;
+ if (!worker) {
+ return 0;
+ }
+ lua_newtable(L);
+ lua_pushnumber(L, worker->stats.queries);
+ lua_setfield(L, -2, "queries");
+ lua_pushnumber(L, worker->stats.concurrent);
+ lua_setfield(L, -2, "concurrent");
+ lua_pushnumber(L, worker->stats.dropped);
+ lua_setfield(L, -2, "dropped");
+
+ lua_pushnumber(L, worker->stats.timeout);
+ lua_setfield(L, -2, "timeout");
+ lua_pushnumber(L, worker->stats.udp);
+ lua_setfield(L, -2, "udp");
+ lua_pushnumber(L, worker->stats.tcp);
+ lua_setfield(L, -2, "tcp");
+ lua_pushnumber(L, worker->stats.tls);
+ lua_setfield(L, -2, "tls");
+ lua_pushnumber(L, worker->stats.ipv4);
+ lua_setfield(L, -2, "ipv4");
+ lua_pushnumber(L, worker->stats.ipv6);
+ lua_setfield(L, -2, "ipv6");
+
+ /* Add subset of rusage that represents counters. */
+ uv_rusage_t rusage;
+ if (uv_getrusage(&rusage) == 0) {
+ lua_pushnumber(L, getseconds(&rusage.ru_utime));
+ lua_setfield(L, -2, "usertime");
+ lua_pushnumber(L, getseconds(&rusage.ru_stime));
+ lua_setfield(L, -2, "systime");
+ lua_pushnumber(L, rusage.ru_majflt);
+ lua_setfield(L, -2, "pagefaults");
+ lua_pushnumber(L, rusage.ru_nswap);
+ lua_setfield(L, -2, "swaps");
+ lua_pushnumber(L, rusage.ru_nvcsw + rusage.ru_nivcsw);
+ lua_setfield(L, -2, "csw");
+ }
+ /* Get RSS */
+ size_t rss = 0;
+ if (uv_resident_set_memory(&rss) == 0) {
+ lua_pushnumber(L, rss);
+ lua_setfield(L, -2, "rss");
+ }
+ return 1;
+}
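+
+/* Illustrative worker.stats() result (all values are runtime counters):
+ *   { queries = 1234, concurrent = 2, dropped = 0, timeout = 1, udp = 1000,
+ *     tcp = 200, tls = 34, ipv4 = 900, ipv6 = 334, usertime = 0.42,
+ *     systime = 0.10, pagefaults = 0, swaps = 0, csw = 815, rss = 52428800 }
+ */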
+
+int kr_bindings_worker(lua_State *L)
+{
+ static const luaL_Reg lib[] = {
+ { "stats", wrk_stats },
+ { NULL, NULL }
+ };
+ luaL_register(L, "worker", lib);
+ return 1;
+}
+
--- /dev/null
+
+#include "lib/utils.h"
+#include "systemd/sd-bus.h"
+
+#include "dbus_control.h"
+
+#define DBUS_SD_NAME "org.freedesktop.systemd1"
+#define DBUS_SD_PATH "/org/freedesktop/systemd1"
+
+#define DBUS_INT_MAN "org.freedesktop.systemd1.Manager"
+#define DBUS_INT_PROP "org.freedesktop.DBus.Properties"
+#define DBUS_INT_UNIT "org.freedesktop.systemd1.Unit"
+
+#define SERVICE_KRESD "kresd@%s.service"
+#define SERVICE_CACHE_GC "kres-cache-gc.service"
+
+#define DBUS_PATH_KRESD "/org/freedesktop/systemd1/unit/kresd_40%s_2eservice"
+#define DBUS_PATH_GC "/org/freedesktop/systemd1/unit/kres_2dcache_2dgc_2eservice"
+
+
+int kresd_get_status(const char *instance, char **status)
+{
+	sd_bus *bus = NULL;
+	char *instance_path = NULL;
+	sd_bus_error err = SD_BUS_ERROR_NULL;
+
+	int ret = sd_bus_default_system(&bus);
+	if (ret < 0)
+		return ret;
+
+	/* Build the escaped unit object path for kresd@<instance>.service. */
+	if (asprintf(&instance_path, DBUS_PATH_KRESD, instance) < 0) {
+		sd_bus_unref(bus);
+		return kr_error(ENOMEM);
+	}
+
+	ret = sd_bus_get_property_string(
+		bus,
+		DBUS_SD_NAME,
+		instance_path,
+		DBUS_INT_UNIT,
+		"ActiveState",
+		&err,
+		status
+	);
+	free(instance_path);
+	sd_bus_error_free(&err);
+	sd_bus_unref(bus);
+	return ret;
+}
+
+int cache_gc_get_status(char **status)
+{
+	sd_bus *bus = NULL;
+	sd_bus_error err = SD_BUS_ERROR_NULL;
+
+	int ret = sd_bus_default_system(&bus);
+	if (ret < 0)
+		return ret;
+
+	ret = sd_bus_get_property_string(
+		bus,
+		DBUS_SD_NAME,
+		DBUS_PATH_GC,
+		DBUS_INT_UNIT,
+		"ActiveState",
+		&err,
+		status
+	);
+	sd_bus_error_free(&err);
+	sd_bus_unref(bus);
+	return ret;
+}
+
+int kresd_ctl(const char *method, const char *instance)
+{
+	sd_bus *bus = NULL;
+	char *instance_service = NULL;
+	sd_bus_message *reply = NULL;
+	sd_bus_error error = SD_BUS_ERROR_NULL;
+
+	int ret = sd_bus_default_system(&bus);
+	if (ret < 0)
+		return ret;
+
+	if (asprintf(&instance_service, SERVICE_KRESD, instance) < 0) {
+		sd_bus_unref(bus);
+		return kr_error(ENOMEM);
+	}
+
+	ret = sd_bus_call_method(bus, DBUS_SD_NAME, DBUS_SD_PATH, DBUS_INT_MAN,
+		method, &error, &reply, "ss", instance_service, "replace");
+	if (ret < 0) {
+		kr_log_error(
+			"[sdbus] failed to issue method call '%s' %s: %s\n",
+			method, instance_service, error.message);
+		goto cleanup;
+	}
+
+	kr_log_info("[sdbus] %s %s\n", method, instance_service);
+
+cleanup:
+	free(instance_service);
+	sd_bus_error_free(&error);
+	sd_bus_message_unref(reply);
+	sd_bus_unref(bus);
+
+	return ret;
+}
+
+int cache_gc_ctl(const char *method)
+{
+	sd_bus *bus = NULL;
+	sd_bus_message *reply = NULL;
+	sd_bus_error error = SD_BUS_ERROR_NULL;
+
+	int ret = sd_bus_default_system(&bus);
+	if (ret < 0)
+		return ret;
+
+	ret = sd_bus_call_method(bus, DBUS_SD_NAME, DBUS_SD_PATH, DBUS_INT_MAN,
+		method, &error, &reply, "ss", SERVICE_CACHE_GC, "replace");
+	if (ret < 0) {
+		kr_log_error(
+			"[sdbus] failed to issue method call '%s' %s: %s\n",
+			method, SERVICE_CACHE_GC, error.message);
+		goto cleanup;
+	}
+
+	kr_log_info("[sdbus] %s %s\n", method, SERVICE_CACHE_GC);
+
+cleanup:
+	sd_bus_error_free(&error);
+	sd_bus_message_unref(reply);
+	sd_bus_unref(bus);
+
+	return ret;
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <uv.h>
+#include <systemd/sd-bus.h>
+
+#define UNIT_START "StartUnit"
+#define UNIT_STOP "StopUnit"
+#define UNIT_RESTART "RestartUnit"
+
+
+typedef struct sdbus_uv_ctx sdbus_ctx_t;
+
+typedef void (*sdbus_cb)(sdbus_ctx_t *sdbus_ctx, int status);
+
+/** Context for sdbus.
+ * More fields may be added in the future. */
+struct sdbus_uv_ctx {
+ sd_bus *bus;
+ sdbus_cb callback;
+ uv_poll_t uv_handle;
+};
+
+int kresd_get_status(const char *instance, char **status);
+
+int cache_gc_get_status(char **status);
+
+int kresd_ctl(const char *method, const char *instance);
+
+int cache_gc_ctl(const char *method);
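+
+/* Usage sketch (assumes systemd units kresd@1.service and kres-cache-gc.service exist):
+ *   char *state = NULL;
+ *   if (kresd_get_status("1", &state) >= 0) {
+ *       // state holds e.g. "active" or "inactive"; the caller frees it
+ *       free(state);
+ *   }
+ *   kresd_ctl(UNIT_RESTART, "1");
+ *   cache_gc_ctl(UNIT_START);
+ */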
--- /dev/null
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <contrib/cleanup.h>
+#include <ccan/json/json.h>
+#include <ccan/asprintf/asprintf.h>
+#include <dlfcn.h>
+#include <uv.h>
+#include <unistd.h>
+#include <grp.h>
+#include <pwd.h>
+#include <sys/param.h>
+#include <libzscanner/scanner.h>
+
+#include <lua.h>
+#include <lualib.h>
+#include <lauxlib.h>
+#include "bindings/impl.h"
+
+#include "kresconfig.h"
+#include "engine.h"
+#include "ffimodule.h"
+#include "lib/nsrep.h"
+#include "lib/cache/api.h"
+#include "lib/defines.h"
+#include "lib/cache/cdb_lmdb.h"
+#include "lib/dnssec/ta.h"
+
+#include "watcher.h"
+
+/* Magic defaults for the engine. */
+#ifndef LRU_RTT_SIZE
+#define LRU_RTT_SIZE 65536 /**< NS RTT cache size */
+#endif
+#ifndef LRU_REP_SIZE
+#define LRU_REP_SIZE (LRU_RTT_SIZE / 4) /**< NS reputation cache size */
+#endif
+#ifndef LRU_COOKIES_SIZE
+ #ifdef ENABLE_COOKIES
+ #define LRU_COOKIES_SIZE LRU_RTT_SIZE /**< DNS cookies cache size. */
+ #else
+ #define LRU_COOKIES_SIZE LRU_ASSOC /* simpler than guards everywhere */
+ #endif
+#endif
+
+/**@internal Maximum number of incomplete TCP connections in queue.
+* Default is from Redis and Apache. */
+#ifndef TCP_BACKLOG_DEFAULT
+#define TCP_BACKLOG_DEFAULT 511
+#endif
+
+/* Cleanup engine state every 5 minutes */
+const size_t CLEANUP_TIMER = 5*60*1000;
+
+/* Execute byte code */
+#define l_dobytecode(L, arr, len, name) \
+ (luaL_loadbuffer((L), (arr), (len), (name)) || lua_pcall((L), 0, LUA_MULTRET, 0))
+
+/*
+ * Global bindings.
+ */
+
+
+/** Print help and available commands. */
+static int l_help(lua_State *L)
+{
+ static const char *help_str =
+ "help()\n show this help\n"
+ "quit()\n quit\n"
+ "hostname()\n hostname\n"
+ "package_version()\n return package version\n"
+ "user(name[, group])\n change process user (and group)\n"
+ "verbose(true|false)\n toggle verbose mode\n"
+ "option(opt[, new_val])\n get/set server option\n"
+ "mode(strict|normal|permissive)\n set resolver strictness level\n"
+ "reorder_RR([true|false])\n set/get reordering of RRs within RRsets\n"
+ "resolve(name, type[, class, flags, callback])\n resolve query, callback when it's finished\n"
+ "todname(name)\n convert name to wire format\n"
+ "tojson(val)\n convert value to JSON\n"
+ "map(expr)\n run expression on all workers\n"
+ "net\n network configuration\n"
+		"cache\n cache configuration\n"
+ "modules\n modules configuration\n"
+ "kres\n resolver services\n"
+ "trust_anchors\n configure trust anchors\n"
+ ;
+ lua_pushstring(L, help_str);
+ return 1;
+}
+
+static bool update_privileges(int uid, int gid)
+{
+ if ((gid_t)gid != getgid()) {
+ if (setregid(gid, gid) < 0) {
+ return false;
+ }
+ }
+ if ((uid_t)uid != getuid()) {
+ if (setreuid(uid, uid) < 0) {
+ return false;
+ }
+ }
+ return true;
+}
+
+/** Set process user/group. */
+static int l_setuser(lua_State *L)
+{
+ int n = lua_gettop(L);
+ if (n < 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "user(user[, group])");
+
+ /* Fetch UID/GID based on string identifiers. */
+ struct passwd *user_pw = getpwnam(lua_tostring(L, 1));
+ if (!user_pw)
+ lua_error_p(L, "invalid user name");
+ int uid = user_pw->pw_uid;
+ int gid = getgid();
+ if (n > 1 && lua_isstring(L, 2)) {
+ struct group *group_pw = getgrnam(lua_tostring(L, 2));
+ if (!group_pw)
+ lua_error_p(L, "invalid group name");
+ gid = group_pw->gr_gid;
+ }
+ /* Drop privileges */
+ bool ret = update_privileges(uid, gid);
+ if (!ret) {
+ lua_error_maybe(L, errno);
+ }
+ lua_pushboolean(L, ret);
+ return 1;
+}
+
+/** Quit current executable. */
+static int l_quit(lua_State *L)
+{
+ engine_stop(engine_luaget(L));
+ return 0;
+}
+
+/** Toggle verbose mode. */
+static int l_verbose(lua_State *L)
+{
+ if (lua_isboolean(L, 1) || lua_isnumber(L, 1)) {
+ kr_verbose_set(lua_toboolean(L, 1));
+ }
+ lua_pushboolean(L, kr_verbose_status);
+ return 1;
+}
+
+char *engine_get_hostname(struct engine *engine) {
+ static char hostname_str[KNOT_DNAME_MAXLEN];
+ if (!engine) {
+ return NULL;
+ }
+
+ if (!engine->hostname) {
+ if (gethostname(hostname_str, sizeof(hostname_str)) != 0)
+ return NULL;
+ return hostname_str;
+ }
+ return engine->hostname;
+}
+
+int engine_set_hostname(struct engine *engine, const char *hostname) {
+ if (!engine || !hostname) {
+ return kr_error(EINVAL);
+ }
+
+ char *new_hostname = strdup(hostname);
+ if (!new_hostname) {
+ return kr_error(ENOMEM);
+ }
+ if (engine->hostname) {
+ free(engine->hostname);
+ }
+ engine->hostname = new_hostname;
+ //network_new_hostname(&engine->net, engine);
+
+ return 0;
+}
+
+/** Return hostname. */
+static int l_hostname(lua_State *L)
+{
+ struct engine *engine = engine_luaget(L);
+ if (lua_gettop(L) == 0) {
+ lua_pushstring(L, engine_get_hostname(engine));
+ return 1;
+ }
+ if ((lua_gettop(L) != 1) || !lua_isstring(L, 1))
+ lua_error_p(L, "hostname takes at most one parameter: (\"fqdn\")");
+
+ if (engine_set_hostname(engine, lua_tostring(L, 1)) != 0)
+ lua_error_p(L, "setting hostname failed");
+
+ lua_pushstring(L, engine_get_hostname(engine));
+ return 1;
+}
+
+/** Return server package version. */
+static int l_package_version(lua_State *L)
+{
+ lua_pushliteral(L, PACKAGE_VERSION);
+ return 1;
+}
+
+/** Load root hints from zonefile. */
+// static int l_hint_root_file(lua_State *L)
+// {
+// struct engine *engine = engine_luaget(L);
+// struct kr_context *ctx = &engine->resolver;
+// const char *file = lua_tostring(L, 1);
+
+// const char *err = engine_hint_root_file(ctx, file);
+// if (err) {
+// if (!file) {
+// file = ROOTHINTS;
+// }
+// lua_error_p(L, "error when opening '%s': %s", file, err);
+// } else {
+// lua_pushboolean(L, true);
+// return 1;
+// }
+// }
+
+/** @internal for engine_hint_root_file */
+static void roothints_add(zs_scanner_t *zs)
+{
+ struct kr_zonecut *hints = zs->process.data;
+ if (!hints) {
+ return;
+ }
+ if (zs->r_type == KNOT_RRTYPE_A || zs->r_type == KNOT_RRTYPE_AAAA) {
+ kr_zonecut_add(hints, zs->r_owner, zs->r_data, zs->r_data_length);
+ }
+}
+const char* engine_hint_root_file(struct kr_context *ctx, const char *file)
+{
+ if (!file) {
+ file = ROOTHINTS;
+ }
+ if (strlen(file) == 0 || !ctx) {
+ return "invalid parameters";
+ }
+ struct kr_zonecut *root_hints = &ctx->root_hints;
+
+ zs_scanner_t zs;
+ if (zs_init(&zs, ".", 1, 0) != 0) {
+ return "not enough memory";
+ }
+ if (zs_set_input_file(&zs, file) != 0) {
+ zs_deinit(&zs);
+ return "failed to open root hints file";
+ }
+
+ kr_zonecut_set(root_hints, (const uint8_t *)"");
+ zs_set_processing(&zs, roothints_add, NULL, root_hints);
+ zs_parse_all(&zs);
+ zs_deinit(&zs);
+ return NULL;
+}
+
+/** Unpack JSON object to table */
+static void l_unpack_json(lua_State *L, JsonNode *table)
+{
+ /* Unpack POD */
+ switch(table->tag) {
+ case JSON_STRING: lua_pushstring(L, table->string_); return;
+ case JSON_NUMBER: lua_pushnumber(L, table->number_); return;
+ case JSON_BOOL: lua_pushboolean(L, table->bool_); return;
+ default: break;
+ }
+ /* Unpack object or array into table */
+ lua_newtable(L);
+ JsonNode *node = NULL;
+ json_foreach(node, table) {
+ /* Push node value */
+ switch(node->tag) {
+ case JSON_OBJECT: /* as array */
+ case JSON_ARRAY: l_unpack_json(L, node); break;
+ case JSON_STRING: lua_pushstring(L, node->string_); break;
+ case JSON_NUMBER: lua_pushnumber(L, node->number_); break;
+ case JSON_BOOL: lua_pushboolean(L, node->bool_); break;
+ default: continue;
+ }
+ /* Set table key */
+ if (node->key) {
+ lua_setfield(L, -2, node->key);
+ } else {
+ lua_rawseti(L, -2, lua_objlen(L, -2) + 1);
+ }
+ }
+}
+
+/** @internal Recursive Lua/JSON serialization. */
+static JsonNode *l_pack_elem(lua_State *L, int top)
+{
+ switch(lua_type(L, top)) {
+ case LUA_TSTRING: return json_mkstring(lua_tostring(L, top));
+ case LUA_TNUMBER: return json_mknumber(lua_tonumber(L, top));
+ case LUA_TBOOLEAN: return json_mkbool(lua_toboolean(L, top));
+ case LUA_TTABLE: break; /* Table, iterate it. */
+ default: return json_mknull();
+ }
+ /* Use absolute indexes here, as the table may be nested. */
+ JsonNode *node = NULL;
+ lua_pushnil(L);
+ while(lua_next(L, top) != 0) {
+ bool is_array = false;
+ if (!node) {
+ is_array = (lua_type(L, top + 1) == LUA_TNUMBER);
+ node = is_array ? json_mkarray() : json_mkobject();
+ if (!node) {
+ return NULL;
+ }
+ } else {
+ is_array = node->tag == JSON_ARRAY;
+ }
+
+ /* Insert to array/table. */
+ JsonNode *val = l_pack_elem(L, top + 2);
+ if (is_array) {
+ json_append_element(node, val);
+ } else {
+ const char *key = lua_tostring(L, top + 1);
+ json_append_member(node, key, val);
+ }
+ lua_pop(L, 1);
+ }
+ /* Return empty object for empty tables. */
+ return node ? node : json_mkobject();
+}
+
+/** @internal Serialize to string */
+static char *l_pack_json(lua_State *L, int top)
+{
+ JsonNode *root = l_pack_elem(L, top);
+ if (!root) {
+ return NULL;
+ }
+ char *result = json_encode(root);
+ json_delete(root);
+ return result;
+}
+
+static int l_tojson(lua_State *L)
+{
+ auto_free char *json_str = l_pack_json(L, lua_gettop(L));
+ if (!json_str) {
+ return 0;
+ }
+ lua_pushstring(L, json_str);
+ return 1;
+}
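+
+/* e.g. tojson({ a = 1, ok = true }) -> '{"a":1,"ok":true}' (member order may vary). */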
+
+static int l_fromjson(lua_State *L)
+{
+ if (lua_gettop(L) != 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "a JSON string is required");
+
+ const char *json_str = lua_tostring(L, 1);
+ JsonNode *root_node = json_decode(json_str);
+
+ if (!root_node)
+ lua_error_p(L, "invalid JSON string");
+ l_unpack_json(L, root_node);
+ json_delete(root_node);
+
+ return 1;
+}
+
+/** @internal Throw Lua error if expr is false */
+#define expr_checked(expr) \
+ if (!(expr)) { lua_pushboolean(L, false); lua_rawseti(L, -2, lua_objlen(L, -2) + 1); continue; }
+
+static int l_map(lua_State *L)
+{
+ if (lua_gettop(L) != 1 || !lua_isstring(L, 1))
+ lua_error_p(L, "map('string with a lua expression')");
+
+ struct engine *engine = engine_luaget(L);
+ const char *cmd = lua_tostring(L, 1);
+ uint32_t len = strlen(cmd);
+ lua_newtable(L);
+
+ /* Execute on leader instance */
+ int ntop = lua_gettop(L);
+ engine_cmd(L, cmd, true);
+ lua_settop(L, ntop + 1); /* Push only one return value to table */
+ lua_rawseti(L, -2, 1);
+
+ for (size_t i = 0; i < engine->ipc_set.len; ++i) {
+ int fd = engine->ipc_set.at[i];
+ /* Send command */
+ expr_checked(write(fd, &len, sizeof(len)) == sizeof(len));
+ expr_checked(write(fd, cmd, len) == len);
+ /* Read response */
+ uint32_t rlen = 0;
+ if (read(fd, &rlen, sizeof(rlen)) == sizeof(rlen)) {
+ expr_checked(rlen < UINT32_MAX);
+ auto_free char *rbuf = malloc(rlen + 1);
+ expr_checked(rbuf != NULL);
+ expr_checked(read(fd, rbuf, rlen) == rlen);
+ rbuf[rlen] = '\0';
+ /* Unpack from JSON */
+ JsonNode *root_node = json_decode(rbuf);
+ if (root_node) {
+ l_unpack_json(L, root_node);
+ } else {
+ lua_pushlstring(L, rbuf, rlen);
+ }
+ json_delete(root_node);
+ lua_rawseti(L, -2, lua_objlen(L, -2) + 1);
+ continue;
+ }
+ /* Didn't respond */
+ lua_pushboolean(L, false);
+ lua_rawseti(L, -2, lua_objlen(L, -2) + 1);
+ }
+ return 1;
+}
+
+#undef expr_checked
+
+
+/*
+ * Engine API.
+ */
+
+static int init_worker(struct engine *engine)
+{
+	/* Note: it had been zeroed by engine_init(). */
+ /* Open resolution context */
+ //engine->resolver.trust_anchors = map_make(NULL);
+ //engine->resolver.negative_anchors = map_make(NULL);
+ //engine->resolver.pool = engine->pool;
+ //engine->resolver.modules = &engine->modules;
+ //engine->resolver.cache_rtt_tout_retry_interval = KR_NS_TIMEOUT_RETRY_INTERVAL;
+ /* Create OPT RR */
+ // engine->resolver.opt_rr = mm_alloc(engine->pool, sizeof(knot_rrset_t));
+ // if (!engine->resolver.opt_rr) {
+ // return kr_error(ENOMEM);
+ // }
+ //knot_edns_init(engine->resolver.opt_rr, KR_EDNS_PAYLOAD, 0, KR_EDNS_VERSION, engine->pool);
+
+ /* Use default TLS padding */
+ //engine->resolver.tls_padding = -1;
+
+ /* Empty init; filled via ./lua/config.lua */
+ //kr_zonecut_init(&engine->resolver.root_hints, (const uint8_t *)"", engine->pool);
+
+ /* Open NS rtt + reputation cache */
+ // lru_create(&engine->resolver.cache_rtt, LRU_RTT_SIZE, NULL, NULL);
+ // lru_create(&engine->resolver.cache_rep, LRU_REP_SIZE, NULL, NULL);
+ // lru_create(&engine->resolver.cache_cookie, LRU_COOKIES_SIZE, NULL, NULL);
+
+ /* Load basic modules */
+ // engine_register(engine, "iterate", NULL, NULL);
+ // engine_register(engine, "validate", NULL, NULL);
+ // engine_register(engine, "cache", NULL, NULL);
+
+ return array_push(engine->backends, kr_cdb_lmdb());
+}
+
+static int init_state(struct engine *engine)
+{
+ /* Initialize Lua state */
+ engine->L = luaL_newstate();
+ if (engine->L == NULL) {
+ return kr_error(ENOMEM);
+ }
+ /* Initialize used libraries. */
+ luaL_openlibs(engine->L);
+ /* Global functions */
+ lua_pushcfunction(engine->L, l_help);
+ lua_setglobal(engine->L, "help");
+ lua_pushcfunction(engine->L, l_quit);
+ lua_setglobal(engine->L, "quit");
+ lua_pushcfunction(engine->L, l_hostname);
+ lua_setglobal(engine->L, "hostname");
+ lua_pushcfunction(engine->L, l_package_version);
+ lua_setglobal(engine->L, "package_version");
+ lua_pushcfunction(engine->L, l_verbose);
+ lua_setglobal(engine->L, "verbose");
+ lua_pushcfunction(engine->L, l_setuser);
+ lua_setglobal(engine->L, "user");
+ // lua_pushcfunction(engine->L, l_hint_root_file);
+ // lua_setglobal(engine->L, "_hint_root_file");
+ lua_pushliteral(engine->L, libknot_SONAME);
+ lua_setglobal(engine->L, "libknot_SONAME");
+ lua_pushliteral(engine->L, libzscanner_SONAME);
+ lua_setglobal(engine->L, "libzscanner_SONAME");
+ lua_pushcfunction(engine->L, l_tojson);
+ lua_setglobal(engine->L, "tojson");
+ lua_pushcfunction(engine->L, l_fromjson);
+ lua_setglobal(engine->L, "fromjson");
+ lua_pushcfunction(engine->L, l_map);
+ lua_setglobal(engine->L, "map");
+ lua_pushlightuserdata(engine->L, engine);
+ lua_setglobal(engine->L, "__engine");
+ return kr_ok();
+}
+
+/**
+ * Start luacov measurement and store results to file specified by
+ * KRESD_COVERAGE_STATS environment variable.
+ * Do nothing if the variable is not set.
+ */
+// static void init_measurement(struct engine *engine)
+// {
+// const char * const statspath = getenv("KRESD_COVERAGE_STATS");
+// if (!statspath)
+// return;
+
+// char * snippet = NULL;
+// int ret = asprintf(&snippet,
+// "_luacov_runner = require('luacov.runner')\n"
+// "_luacov_runner.init({\n"
+// " statsfile = '%s',\n"
+// " exclude = {'test', 'tapered', 'lua/5.1'},\n"
+// "})\n"
+// "jit.off()\n", statspath
+// );
+// assert(ret > 0); (void)ret;
+
+// ret = luaL_loadstring(engine->L, snippet);
+// assert(ret == 0);
+// lua_call(engine->L, 0, 0);
+// free(snippet);
+// }
+
+int init_lua(struct engine *engine) {
+ if (!engine) {
+ return kr_error(EINVAL);
+ }
+
+ /* Use libdir path for including Lua scripts */
+ char l_paths[MAXPATHLEN] = { 0 };
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wformat" /* %1$ is not in C standard */
+ /* Save original package.path to package._path */
+ snprintf(l_paths, MAXPATHLEN - 1,
+ "if package._path == nil then package._path = package.path end\n"
+ "package.path = '%1$s/?.lua;%1$s/?/init.lua;'..package._path\n"
+ "if package._cpath == nil then package._cpath = package.cpath end\n"
+ "package.cpath = '%1$s/?%2$s;'..package._cpath\n",
+ LIBDIR, LIBEXT);
+ #pragma GCC diagnostic pop
+
+ int ret = l_dobytecode(engine->L, l_paths, strlen(l_paths), "");
+ if (ret != 0) {
+ lua_pop(engine->L, 1);
+ return ret;
+ }
+ return 0;
+}
+
+
+int engine_init(struct engine *engine, knot_mm_t *pool)
+{
+ if (engine == NULL) {
+ return kr_error(EINVAL);
+ }
+
+ memset(engine, 0, sizeof(*engine));
+ engine->pool = pool;
+
+ /* Initialize state */
+ int ret = init_state(engine);
+ if (ret != 0) {
+ engine_deinit(engine);
+ return ret;
+ }
+
+ // init_measurement(engine);
+
+ /* Initialize worker */
+ ret = init_worker(engine);
+ if (ret != 0) {
+ engine_deinit(engine);
+ return ret;
+ }
+
+ /* Initialize watcher */
+ watcher_init(&engine->watcher, uv_default_loop());
+
+ /* Initialize lua */
+ ret = init_lua(engine);
+ if (ret != 0) {
+ engine_deinit(engine);
+ return ret;
+ }
+
+ return ret;
+}
+
+/** Unregister a (found) module */
+static void engine_unload(struct engine *engine, struct kr_module *module)
+{
+ auto_free char *name = module->name ? strdup(module->name) : NULL;
+ kr_module_unload(module); /* beware: lua/C mix, could be confusing */
+ /* Clear in Lua world, but not for embedded modules ('cache' in particular). */
+ if (name && !kr_module_get_embedded(name)) {
+ lua_pushnil(engine->L);
+ lua_setglobal(engine->L, name);
+ }
+ free(module);
+}
+
+void engine_deinit(struct engine *engine)
+{
+ if (engine == NULL) {
+ return;
+ }
+ if (!engine->L) {
+ assert(false);
+ return;
+ }
+ /* Only close sockets and services; no need to clean up mempool. */
+
+ /* Network deinit is split up. We first need to stop listening,
+ * then we can unload modules during which we still want
+ * e.g. the endpoint kind registry to work (inside ->net),
+ * and this registry deinitialization uses the lua state. */
+ //network_close_force(&engine->net);
+ for (size_t i = 0; i < engine->ipc_set.len; ++i) {
+ close(engine->ipc_set.at[i]);
+ }
+ for (size_t i = 0; i < engine->modules.len; ++i) {
+ engine_unload(engine, engine->modules.at[i]);
+ }
+ //kr_zonecut_deinit(&engine->resolver.root_hints);
+ //kr_cache_close(&engine->watcher.cache);
+
+ /* The LRUs are currently malloc-ated and need to be freed. */
+ //lru_free(engine->resolver.cache_rtt);
+ //lru_free(engine->resolver.cache_rep);
+ //lru_free(engine->resolver.cache_cookie);
+
+ watcher_deinit(&engine->watcher);
+ ffimodule_deinit(engine->L);
+ lua_close(engine->L);
+
+ /* Free data structures */
+ array_clear(engine->modules);
+ array_clear(engine->backends);
+ array_clear(engine->ipc_set);
+ //kr_ta_clear(&engine->resolver.trust_anchors);
+ //kr_ta_clear(&engine->resolver.negative_anchors);
+ free(engine->hostname);
+}
+
+int engine_pcall(lua_State *L, int argc)
+{
+ return lua_pcall(L, argc, LUA_MULTRET, 0);
+}
+
+int engine_cmd(lua_State *L, const char *str, bool raw)
+{
+ if (L == NULL) {
+ return kr_error(ENOEXEC);
+ }
+
+ /* Evaluate results */
+ lua_getglobal(L, "eval_cmd");
+ lua_pushstring(L, str);
+ lua_pushboolean(L, raw);
+
+ /* Check result. */
+ return engine_pcall(L, 2);
+}
+
+int engine_ipc(struct engine *engine, const char *expr)
+{
+ if (engine == NULL || engine->L == NULL) {
+ return kr_error(ENOEXEC);
+ }
+
+ /* Run expression and serialize response. */
+ engine_cmd(engine->L, expr, true);
+ if (lua_gettop(engine->L) > 0) {
+ l_tojson(engine->L);
+ return 1;
+ } else {
+ return 0;
+ }
+}
+
+int engine_load_sandbox(struct engine *engine)
+{
+ /* Init environment */
+ int ret = luaL_dofile(engine->L, LIBDIR "/sandbox-watcher.lua");
+ if (ret != 0) {
+ fprintf(stderr, "[system] error %s\n", lua_tostring(engine->L, -1));
+ lua_pop(engine->L, 1);
+ return kr_error(ENOEXEC);
+ }
+ ret = ffimodule_init(engine->L);
+ return ret;
+}
+
+int engine_loadconf(struct engine *engine, const char *config_path)
+{
+ assert(config_path != NULL);
+
+ char cwd[PATH_MAX];
+ get_workdir(cwd, sizeof(cwd));
+ kr_log_verbose("[system] loading config '%s' (workdir '%s')\n", config_path, cwd);
+
+ int ret = luaL_dofile(engine->L, config_path);
+ if (ret != 0) {
+ fprintf(stderr, "[system] error while loading config: "
+ "%s (workdir '%s')\n", lua_tostring(engine->L, -1), cwd);
+ lua_pop(engine->L, 1);
+ }
+ return ret;
+}
+
+int engine_start(struct engine *engine)
+{
+ /* Clean up stack */
+ lua_settop(engine->L, 0);
+
+ return kr_ok();
+}
+
+void engine_stop(struct engine *engine)
+{
+ if (!engine) {
+ return;
+ }
+ uv_stop(uv_default_loop());
+}
+
+/** @internal Find matching module */
+static size_t module_find(module_array_t *mod_list, const char *name)
+{
+ size_t found = mod_list->len;
+ for (size_t i = 0; i < mod_list->len; ++i) {
+ struct kr_module *mod = mod_list->at[i];
+ if (strcmp(mod->name, name) == 0) {
+ found = i;
+ break;
+ }
+ }
+ return found;
+}
+
+int engine_register(struct engine *engine, const char *name, const char *precedence, const char* ref)
+{
+ if (engine == NULL || name == NULL) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ /* Make sure module is unloaded */
+ (void) engine_unregister(engine, name);
+ /* Find the index of referenced module. */
+ module_array_t *mod_list = &engine->modules;
+ size_t ref_pos = mod_list->len;
+ if (precedence && ref) {
+ ref_pos = module_find(mod_list, ref);
+ if (ref_pos >= mod_list->len) {
+ return kr_error(EIDRM);
+ }
+ }
+ /* Attempt to load binary module */
+ struct kr_module *module = malloc(sizeof(*module));
+ if (!module) {
+ return kr_error(ENOMEM);
+ }
+ module->data = engine; /*< some outside modules may still use this value */
+
+ int ret = kr_module_load(module, name, LIBDIR "/kres_modules");
+ if (ret == 0) {
+ /* We have a C module, loaded and init() was called.
+ * Now we need to prepare the lua side. */
+ lua_State *L = engine->L;
+ lua_getglobal(L, "modules_create_table_for_c");
+ lua_pushpointer(L, module);
+ if (lua_isnil(L, -2)) {
+ /* When loading the three embedded modules, we don't
+ * have the "modules_*" lua function yet, but fortunately
+ * we don't need it there. Let's just check they're embedded.
+ * TODO: solve this better *without* breaking stuff. */
+ lua_pop(L, 2);
+ if (module->lib != RTLD_DEFAULT) {
+ ret = kr_error(1);
+ lua_pushliteral(L, "missing modules_create_table_for_c()");
+ }
+ } else {
+ ret = engine_pcall(L, 1);
+ }
+ if (ret) {
+ kr_log_error("[system] internal error when loading C module %s: %s\n",
+ module->name, lua_tostring(L, -1));
+ lua_pop(L, 1);
+ assert(false); /* probably not critical, but weird */
+ }
+
+ } else if (ret == kr_error(ENOENT)) {
+ /* No luck with C module, so try to load and .init() lua module. */
+ ret = ffimodule_register_lua(engine, module, name);
+ if (ret != 0) {
+ kr_log_error("[system] failed to load module '%s'\n", name);
+ }
+
+ } else if (ret == kr_error(ENOTSUP)) {
+ /* Print a more helpful message when module is linked against an old resolver ABI. */
+ kr_log_error("[system] module '%s' links to unsupported ABI, please rebuild it\n", name);
+ }
+
+ if (ret != 0) {
+ engine_unload(engine, module);
+ return ret;
+ }
+
+ /* Push to the right place in engine->modules */
+ if (array_push(engine->modules, module) < 0) {
+ engine_unload(engine, module);
+ return kr_error(ENOMEM);
+ }
+ if (precedence) {
+ struct kr_module **arr = mod_list->at;
+ size_t emplacement = mod_list->len;
+ if (strcasecmp(precedence, ">") == 0) {
+ if (ref_pos + 1 < mod_list->len)
+ emplacement = ref_pos + 1; /* Insert after target */
+ }
+ if (strcasecmp(precedence, "<") == 0) {
+ emplacement = ref_pos; /* Insert at target */
+ }
+ /* Move the tail if it has some elements. */
+ if (emplacement + 1 < mod_list->len) {
+ memmove(&arr[emplacement + 1], &arr[emplacement], sizeof(*arr) * (mod_list->len - (emplacement + 1)));
+ arr[emplacement] = module;
+ }
+ }
+
+ return kr_ok();
+}
+
+int engine_unregister(struct engine *engine, const char *name)
+{
+ module_array_t *mod_list = &engine->modules;
+ size_t found = module_find(mod_list, name);
+ if (found < mod_list->len) {
+ engine_unload(engine, mod_list->at[found]);
+ array_del(*mod_list, found);
+ return kr_ok();
+ }
+
+ return kr_error(ENOENT);
+}
+
+struct engine *engine_luaget(lua_State *L)
+{
+ lua_getglobal(L, "__engine");
+ struct engine *engine = lua_touserdata(L, -1);
+ if (!engine) luaL_error(L, "internal error, empty engine pointer");
+ lua_pop(L, 1);
+ return engine;
+}
--- /dev/null
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+/*
+ * @internal These are forward decls to allow building modules with engine but without Lua.
+ */
+struct lua_State;
+
+#include "lib/utils.h"
+#include "lib/resolve.h"
+#include "network.h"
+
+#include "watcher.h"
+
+/* @internal Array of file descriptors shorthand. */
+typedef array_t(int) fd_array_t;
+
+struct engine {
+ struct kr_context resolver;
+ struct watcher_context watcher;
+ struct network net;
+ module_array_t modules;
+ array_t(const struct kr_cdb_api *) backends;
+ fd_array_t ipc_set;
+ knot_mm_t *pool;
+ char *hostname;
+ struct lua_State *L;
+};
+
+int engine_init(struct engine *engine, knot_mm_t *pool);
+void engine_deinit(struct engine *engine);
+
+/** Perform a lua command within the sandbox.
+ *
+ * @return zero on success.
+ * The result will be returned on the lua stack - an error message in case of failure.
+ * http://www.lua.org/manual/5.1/manual.html#lua_pcall */
+int engine_cmd(struct lua_State *L, const char *str, bool raw);
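+/* Illustrative call pattern (a sketch only; the command string is arbitrary):
+ *
+ *   if (engine_cmd(L, "help()", true) != 0) {
+ *       kr_log_error("%s\n", lua_tostring(L, -1));  // error message is on top of the stack
+ *       lua_pop(L, 1);
+ *   }
+ *
+ * engine_ipc() shows the in-tree usage, including result serialization. */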
+
+/** Execute current chunk in the sandbox */
+int engine_pcall(struct lua_State *L, int argc);
+
+int engine_ipc(struct engine *engine, const char *expr);
+
+
+int engine_load_sandbox(struct engine *engine);
+int engine_loadconf(struct engine *engine, const char *config_path);
+
+/** Start the lua engine and execute the config. */
+int engine_start(struct engine *engine);
+void engine_stop(struct engine *engine);
+int engine_register(struct engine *engine, const char *name, const char *precedence, const char* ref);
+int engine_unregister(struct engine *engine, const char *name);
+
+/** Return engine light userdata. */
+struct engine *engine_luaget(struct lua_State *L);
+
+/** Set/get the per engine hostname */
+char *engine_get_hostname(struct engine *engine);
+int engine_set_hostname(struct engine *engine, const char *hostname);
+
+/** Load root hints from a zonefile (or config-time default if NULL).
+ *
+ * @return error message or NULL (statically allocated)
+ * @note exported to be usable from the hints module.
+ */
+KR_EXPORT
+const char* engine_hint_root_file(struct kr_context *ctx, const char *file);
+
--- /dev/null
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <uv.h>
+#include <lua.h>
+#include <lauxlib.h>
+
+#include "bindings/impl.h"
+#include "engine.h"
+#include "ffimodule.h"
+#include "worker.h"
+#include "lib/module.h"
+#include "lib/layer.h"
+
+/** @internal Slots for layer callbacks.
+ * Each slot ID corresponds to Lua reference in module API. */
+enum {
+ SLOT_begin = 0,
+ SLOT_reset,
+ SLOT_finish,
+ SLOT_consume,
+ SLOT_produce,
+ SLOT_checkout,
+ SLOT_answer_finalize,
+ SLOT_count /* dummy, must be the last */
+};
+
+/** Lua registry indices for functions that wrap layer callbacks (shared by all lua modules). */
+static int l_ffi_wrap_slots[SLOT_count] = { 0 };
+
+/** @internal Continue with coroutine. */
+static void l_ffi_resume_cb(uv_idle_t *check)
+{
+ lua_State *L = check->data;
+ int status = lua_resume(L, 0);
+ if (status != LUA_YIELD) {
+ uv_idle_stop(check); /* Stop coroutine */
+ uv_close((uv_handle_t *)check, (uv_close_cb)free);
+ }
+ lua_pop(L, lua_gettop(L));
+}
+
+/** @internal Schedule deferred continuation. */
+static int l_ffi_defer(lua_State *L)
+{
+ uv_idle_t *check = malloc(sizeof(*check));
+ if (!check) {
+ return kr_error(ENOMEM);
+ }
+ uv_idle_init(uv_default_loop(), check);
+ check->data = L;
+ return uv_idle_start(check, l_ffi_resume_cb);
+}
+
+/** @internal Helper for calling the entrypoint, for kr_module functions. */
+static int l_ffi_call_mod(lua_State *L, int argc)
+{
+ int status = lua_pcall(L, argc, 1, 0);
+ if (status != 0) {
+ kr_log_error("error: %s\n", lua_tostring(L, -1));
+ lua_pop(L, 1);
+ return kr_error(EIO);
+ }
+ if (lua_isnumber(L, -1)) { /* Return code */
+ status = lua_tointeger(L, -1);
+ } else if (lua_isthread(L, -1)) { /* Continuations */
+ /* TODO: unused, possibly in a bad shape. Meant KR_STATE_YIELD? */
+ assert(!ENOTSUP);
+ status = l_ffi_defer(lua_tothread(L, -1));
+ }
+ lua_pop(L, 1);
+ return status;
+}
+
+/** Common part of calling modname.(de)init in lua.
+ * The function to call should be on top of the stack and it gets popped. */
+static int l_ffi_modcb(lua_State *L, struct kr_module *module)
+{
+ if (lua_isnil(L, -1)) {
+ lua_pop(L, 1); /* .(de)init == nil, maybe even the module table doesn't exist */
+ return kr_ok();
+ }
+ lua_getglobal(L, "modules_ffi_wrap_modcb");
+ lua_insert(L, -2); /* swap with .(de)init */
+ lua_pushpointer(L, module);
+ if (lua_pcall(L, 2, 0, 0) == 0)
+ return kr_ok();
+ kr_log_error("error: %s\n", lua_tostring(L, -1));
+ lua_pop(L, 1);
+ return kr_error(1);
+}
+
+static int l_ffi_deinit(struct kr_module *module)
+{
+ /* Call .deinit(), if it exists. */
+ lua_State *L = the_worker->engine->L;
+ lua_getglobal(L, module->name);
+ lua_getfield(L, -1, "deinit");
+ const int ret = l_ffi_modcb(L, module);
+ lua_pop(L, 1); /* the module's table */
+
+ const kr_layer_api_t *api = module->layer;
+ if (!api) {
+ return ret;
+ }
+ /* Unregister layer callback references from registry. */
+ for (int si = 0; si < SLOT_count; ++si) {
+ if (api->cb_slots[si] > 0) {
+ luaL_unref(L, LUA_REGISTRYINDEX, api->cb_slots[si]);
+ }
+ }
+ free_const(api);
+ return ret;
+}
+
+kr_layer_t kr_layer_t_static;
+
+/** @internal Helper for calling a layer Lua function by e.g. SLOT_begin. */
+static int l_ffi_call_layer(kr_layer_t *ctx, int slot_ix)
+{
+ const int wrap_slot = l_ffi_wrap_slots[slot_ix];
+ const int cb_slot = ctx->api->cb_slots[slot_ix];
+ assert(wrap_slot > 0 && cb_slot > 0);
+ lua_State *L = the_worker->engine->L;
+ lua_rawgeti(L, LUA_REGISTRYINDEX, wrap_slot);
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cb_slot);
+ /* We pass the content of *ctx via a global structure to avoid
+ * lua (full) userdata, as that's relatively expensive (GC-allocated).
+ * Performance: copying isn't ideal, but it's not visible in profiles. */
+ memcpy(&kr_layer_t_static, ctx, sizeof(*ctx));
+ const int ret = l_ffi_call_mod(L, 1);
+ /* The return codes are mixed at this point. We need to return KR_STATE_* */
+ return ret < 0 ? KR_STATE_FAIL : ret;
+}
+
+static int l_ffi_layer_begin(kr_layer_t *ctx)
+{
+ return l_ffi_call_layer(ctx, SLOT_begin);
+}
+
+static int l_ffi_layer_reset(kr_layer_t *ctx)
+{
+ return l_ffi_call_layer(ctx, SLOT_reset);
+}
+
+static int l_ffi_layer_finish(kr_layer_t *ctx)
+{
+ ctx->pkt = ctx->req->answer;
+ return l_ffi_call_layer(ctx, SLOT_finish);
+}
+
+static int l_ffi_layer_consume(kr_layer_t *ctx, knot_pkt_t *pkt)
+{
+ if (ctx->state & KR_STATE_FAIL) {
+ return ctx->state; /* Already failed, skip */
+ }
+ ctx->pkt = pkt;
+ return l_ffi_call_layer(ctx, SLOT_consume);
+}
+
+static int l_ffi_layer_produce(kr_layer_t *ctx, knot_pkt_t *pkt)
+{
+ if (ctx->state & KR_STATE_FAIL) {
+ return ctx->state; /* Already failed, skip */
+ }
+ ctx->pkt = pkt;
+ return l_ffi_call_layer(ctx, SLOT_produce);
+}
+
+static int l_ffi_layer_checkout(kr_layer_t *ctx, knot_pkt_t *pkt,
+ struct sockaddr *dst, int type)
+{
+ if (ctx->state & KR_STATE_FAIL) {
+ return ctx->state; /* Already failed, skip */
+ }
+ ctx->pkt = pkt;
+ ctx->dst = dst;
+ ctx->is_stream = (type == SOCK_STREAM);
+ return l_ffi_call_layer(ctx, SLOT_checkout);
+}
+
+static int l_ffi_layer_answer_finalize(kr_layer_t *ctx)
+{
+ return l_ffi_call_layer(ctx, SLOT_answer_finalize);
+}
+
+int ffimodule_init(lua_State *L)
+{
+ /* Wrappers defined in ./lua/sandbox.lua */
+ /* for API: (int state, kr_request_t *req) */
+ lua_getglobal(L, "modules_ffi_layer_wrap1");
+ const int wrap1 = luaL_ref(L, LUA_REGISTRYINDEX);
+ /* for API: (int state, kr_request_t *req, knot_pkt_t *) */
+ lua_getglobal(L, "modules_ffi_layer_wrap2");
+ const int wrap2 = luaL_ref(L, LUA_REGISTRYINDEX);
+ lua_getglobal(L, "modules_ffi_layer_wrap_checkout");
+ const int wrap_checkout = luaL_ref(L, LUA_REGISTRYINDEX);
+ if (wrap1 == LUA_REFNIL || wrap2 == LUA_REFNIL || wrap_checkout == LUA_REFNIL) {
+ return kr_error(ENOENT);
+ }
+
+ const int slots[SLOT_count] = {
+ [SLOT_begin] = wrap1,
+ [SLOT_reset] = wrap1,
+ [SLOT_finish] = wrap2,
+ [SLOT_consume] = wrap2,
+ [SLOT_produce] = wrap2,
+ [SLOT_checkout] = wrap_checkout,
+ [SLOT_answer_finalize] = wrap1,
+ };
+ memcpy(l_ffi_wrap_slots, slots, sizeof(l_ffi_wrap_slots));
+ return kr_ok();
+}
+void ffimodule_deinit(lua_State *L)
+{
+ /* Unref each wrapper function from lua.
+ * It's probably useless, as we're about to destroy lua_State, but... */
+ const int wrapsIndices[] = {
+ SLOT_begin,
+ SLOT_consume,
+ SLOT_checkout,
+ };
+ for (int i = 0; i < sizeof(wrapsIndices) / sizeof(wrapsIndices[0]); ++i) {
+ luaL_unref(L, LUA_REGISTRYINDEX, l_ffi_wrap_slots[wrapsIndices[i]]);
+ }
+}
+
+/** @internal Conditionally register layer trampoline
+ * @warning Expects 'module.layer' to be on top of Lua stack. */
+#define LAYER_REGISTER(L, api, name) do { \
+ int *cb_slot = (api)->cb_slots + SLOT_ ## name; \
+ lua_getfield((L), -1, #name); \
+ if (!lua_isnil((L), -1)) { \
+ (api)->name = l_ffi_layer_ ## name; \
+ *cb_slot = luaL_ref((L), LUA_REGISTRYINDEX); \
+ } else { \
+ lua_pop((L), 1); \
+ } \
+} while(0)
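+/* For illustration, LAYER_REGISTER(L, api, consume) roughly expands to the
+ * following (a sketch; the real expansion is wrapped in do { ... } while(0)):
+ *
+ *   int *cb_slot = api->cb_slots + SLOT_consume;
+ *   lua_getfield(L, -1, "consume");
+ *   if (!lua_isnil(L, -1)) {
+ *       api->consume = l_ffi_layer_consume;
+ *       *cb_slot = luaL_ref(L, LUA_REGISTRYINDEX);
+ *   } else {
+ *       lua_pop(L, 1);
+ *   }
+ */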
+
+/** @internal Create C layer api wrapper. */
+static kr_layer_api_t *l_ffi_layer_create(lua_State *L, struct kr_module *module)
+{
+ /* Fabricate a layer API wrapping the Lua functions and
+ * reserve slots after it for references to the Lua callbacks. */
+ const size_t api_length = offsetof(kr_layer_api_t, cb_slots)
+ + (SLOT_count * sizeof(module->layer->cb_slots[0]));
+ kr_layer_api_t *api = malloc(api_length);
+ if (api) {
+ memset(api, 0, api_length);
+ LAYER_REGISTER(L, api, begin);
+ LAYER_REGISTER(L, api, finish);
+ LAYER_REGISTER(L, api, consume);
+ LAYER_REGISTER(L, api, produce);
+ LAYER_REGISTER(L, api, checkout);
+ LAYER_REGISTER(L, api, answer_finalize);
+ LAYER_REGISTER(L, api, reset);
+ }
+ return api;
+}
+
+#undef LAYER_REGISTER
+
+int ffimodule_register_lua(struct engine *engine, struct kr_module *module, const char *name)
+{
+ /* Register module in Lua */
+ lua_State *L = engine->L;
+ lua_getglobal(L, "require");
+ lua_pushfstring(L, "kres_modules.%s", name);
+ if (lua_pcall(L, 1, LUA_MULTRET, 0) != 0) {
+ kr_log_error("error: %s\n", lua_tostring(L, -1));
+ lua_pop(L, 1);
+ return kr_error(ENOENT);
+ }
+ lua_setglobal(L, name);
+ lua_getglobal(L, name);
+
+ /* Create FFI module with trampolined functions. */
+ memset(module, 0, sizeof(*module));
+ module->name = strdup(name);
+ module->deinit = &l_ffi_deinit;
+ /* Bake layer API if defined in module */
+ lua_getfield(L, -1, "layer");
+ if (!lua_isnil(L, -1)) {
+ module->layer = l_ffi_layer_create(L, module);
+ }
+ lua_pop(L, 1); /* .layer table */
+
+ /* Now call .init(), if it exists. */
+ lua_getfield(L, -1, "init");
+ const int ret = l_ffi_modcb(L, module);
+ lua_pop(L, 1); /* the module's table */
+ return ret;
+}
--- /dev/null
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "lib/defines.h"
+#include "lib/layer.h"
+#include <lua.h>
+struct engine;
+struct kr_module;
+
+/**
+ * Register Lua module as a FFI module.
+ * This fabricates a standard module interface,
+ * that trampolines to the Lua module methods.
+ *
+ * @note The Lua module is loaded in its own coroutine,
+ * so it's possible to yield and resume at arbitrary
+ * places except deinit()
+ *
+ * @param engine daemon engine
+ * @param module prepared module
+ * @param name module name
+ * @return 0 or an error
+ */
+int ffimodule_register_lua(struct engine *engine, struct kr_module *module, const char *name);
+
+int ffimodule_init(lua_State *L);
+void ffimodule_deinit(lua_State *L);
+
+/** Static storage for faster passing of layer function parameters to lua callbacks.
+ *
+ * We don't need to declare it in a header, but let's give it visibility. */
+KR_EXPORT kr_layer_t kr_layer_t_static;
+
--- /dev/null
+/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <string.h>
+#include <libknot/errcode.h>
+#include <contrib/ucw/lib.h>
+#include <contrib/ucw/mempool.h>
+#include <assert.h>
+
+#include "io.h"
+#include "network.h"
+#include "worker.h"
+#include "tls.h"
+#include "session.h"
+
+#define negotiate_bufsize(func, handle, bufsize_want) do { \
+ int bufsize = 0; (func)((handle), &bufsize); \
+ if (bufsize < (bufsize_want)) { \
+ bufsize = (bufsize_want); \
+ (func)((handle), &bufsize); \
+ } \
+} while (0)
+
+static void check_bufsize(uv_handle_t* handle)
+{
+ return; /* TODO: resurrect after https://github.com/libuv/libuv/issues/419 */
+ /* We want to buffer at least N waves in advance.
+ * This is magic presuming we can pull in a whole recvmmsg width in one wave.
+ * Linux will double the bufsize we request.
+ */
+ const int bufsize_want = 2 * sizeof( ((struct worker_ctx *)NULL)->wire_buf ) ;
+ negotiate_bufsize(uv_recv_buffer_size, handle, bufsize_want);
+ negotiate_bufsize(uv_send_buffer_size, handle, bufsize_want);
+}
+
+#undef negotiate_bufsize
+
+static void handle_getbuf(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf)
+{
+ /* UDP sessions use worker buffer for wire data,
+ * TCP sessions use session buffer for wire data
+ * (see session_set_handle()).
+ * TLS sessions use buffer from TLS context.
+ * The content of the worker buffer is
+ * guaranteed to be unchanged only for the duration of
+ * udp_read() and tcp_read().
+ */
+ struct session *s = handle->data;
+ if (!session_flags(s)->has_tls) {
+ buf->base = (char *) session_wirebuf_get_free_start(s);
+ buf->len = session_wirebuf_get_free_size(s);
+ } else {
+ struct tls_common_ctx *ctx = session_tls_get_common_ctx(s);
+ buf->base = (char *) ctx->recv_buf;
+ buf->len = sizeof(ctx->recv_buf);
+ }
+}
+
+void udp_recv(uv_udp_t *handle, ssize_t nread, const uv_buf_t *buf,
+ const struct sockaddr *addr, unsigned flags)
+{
+ uv_loop_t *loop = handle->loop;
+ struct worker_ctx *worker = loop->data;
+ struct session *s = handle->data;
+ if (session_flags(s)->closing) {
+ return;
+ }
+ if (nread <= 0) {
+ if (nread < 0) { /* Error response, notify resolver */
+ worker_submit(s, NULL, NULL);
+ } /* nread == 0 is for freeing buffers, we don't need to do this */
+ return;
+ }
+ if (addr->sa_family == AF_UNSPEC) {
+ return;
+ }
+ if (session_flags(s)->outgoing) {
+ const struct sockaddr *peer = session_get_peer(s);
+ assert(peer->sa_family != AF_UNSPEC);
+ if (kr_sockaddr_cmp(peer, addr) != 0) {
+ kr_log_verbose("[io] <= ignoring UDP from unexpected address '%s'\n",
+ kr_straddr(addr));
+ return;
+ }
+ }
+ ssize_t consumed = session_wirebuf_consume(s, (const uint8_t *)buf->base,
+ nread);
+ assert(consumed == nread); (void)consumed;
+ session_wirebuf_process(s, addr);
+ session_wirebuf_discard(s);
+ mp_flush(worker->pkt_pool.ctx);
+}
+
+static int family_to_freebind_option(sa_family_t sa_family, int *level, int *name)
+{
+ switch (sa_family) {
+ case AF_INET:
+ *level = IPPROTO_IP;
+#if defined(IP_FREEBIND)
+ *name = IP_FREEBIND;
+#elif defined(IP_BINDANY)
+ *name = IP_BINDANY;
+#else
+ return kr_error(ENOTSUP);
+#endif
+ break;
+ case AF_INET6:
+#if defined(IP_FREEBIND)
+ *level = IPPROTO_IP;
+ *name = IP_FREEBIND;
+#elif defined(IPV6_BINDANY)
+ *level = IPPROTO_IPV6;
+ *name = IPV6_BINDANY;
+#else
+ return kr_error(ENOTSUP);
+#endif
+ break;
+ default:
+ return kr_error(ENOTSUP);
+ }
+ return kr_ok();
+}
+
+int io_bind(const struct sockaddr *addr, int type, const endpoint_flags_t *flags)
+{
+ const int fd = socket(addr->sa_family, type, 0);
+ if (fd < 0) return kr_error(errno);
+
+ int yes = 1;
+ if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)))
+ return kr_error(errno);
+
+#ifdef SO_REUSEPORT_LB
+ if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT_LB, &yes, sizeof(yes)))
+ return kr_error(errno);
+#elif defined(SO_REUSEPORT) && defined(__linux__) /* different meaning on (Free)BSD */
+ if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &yes, sizeof(yes)))
+ return kr_error(errno);
+#endif
+
+#ifdef IPV6_V6ONLY
+ if (addr->sa_family == AF_INET6
+ && setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &yes, sizeof(yes)))
+ return kr_error(errno);
+#endif
+ if (flags != NULL && flags->freebind) {
+ int optlevel;
+ int optname;
+ int ret = family_to_freebind_option(addr->sa_family, &optlevel, &optname);
+ if (ret) return kr_error(ret);
+ if (setsockopt(fd, optlevel, optname, &yes, sizeof(yes)))
+ return kr_error(errno);
+ }
+
+ if (bind(fd, addr, kr_sockaddr_len(addr)))
+ return kr_error(errno);
+
+ return fd;
+}
+
+int io_listen_udp(uv_loop_t *loop, uv_udp_t *handle, int fd)
+{
+ if (!handle) {
+ return kr_error(EINVAL);
+ }
+ int ret = uv_udp_init(loop, handle);
+ if (ret) return ret;
+
+ ret = uv_udp_open(handle, fd);
+ if (ret) return ret;
+
+ uv_handle_t *h = (uv_handle_t *)handle;
+ check_bufsize(h);
+ /* Handle is already created, just create context. */
+ struct session *s = session_new(h, false);
+ assert(s);
+ session_flags(s)->outgoing = false;
+
+ int socklen = sizeof(union inaddr);
+ ret = uv_udp_getsockname(handle, session_get_sockname(s), &socklen);
+ if (ret) {
+ kr_log_error("ERROR: getsockname failed: %s\n", uv_strerror(ret));
+ abort(); /* It might be nontrivial not to leak something here. */
+ }
+
+ return io_start_read(h);
+}
+
+void tcp_timeout_trigger(uv_timer_t *timer)
+{
+ struct session *s = timer->data;
+
+ assert(!session_flags(s)->closing);
+
+ struct worker_ctx *worker = timer->loop->data;
+
+ if (!session_tasklist_is_empty(s)) {
+ int finalized = session_tasklist_finalize_expired(s);
+ worker->stats.timeout += finalized;
+ /* session_tasklist_finalize_expired() may call worker_task_finalize().
+ * If the session is a source session and there were IO errors,
+ * worker_task_finalize() can finalize all tasks and close the session. */
+ if (session_flags(s)->closing) {
+ return;
+ }
+
+ }
+ if (!session_tasklist_is_empty(s)) {
+ uv_timer_stop(timer);
+ session_timer_start(s, tcp_timeout_trigger,
+ KR_RESOLVE_TIME_LIMIT / 2,
+ KR_RESOLVE_TIME_LIMIT / 2);
+ } else {
+ /* Normally it should not happen,
+ * but better to check whether there is anything in this list. */
+ while (!session_waitinglist_is_empty(s)) {
+ struct qr_task *t = session_waitinglist_pop(s, false);
+ worker_task_finalize(t, KR_STATE_FAIL);
+ worker_task_unref(t);
+ worker->stats.timeout += 1;
+ if (session_flags(s)->closing) {
+ return;
+ }
+ }
+ const struct engine *engine = worker->engine;
+ const struct network *net = &engine->net;
+ uint64_t idle_in_timeout = net->tcp.in_idle_timeout;
+ uint64_t last_activity = session_last_activity(s);
+ uint64_t idle_time = kr_now() - last_activity;
+ if (idle_time < idle_in_timeout) {
+ idle_in_timeout -= idle_time;
+ uv_timer_stop(timer);
+ session_timer_start(s, tcp_timeout_trigger,
+ idle_in_timeout, idle_in_timeout);
+ } else {
+ struct sockaddr *peer = session_get_peer(s);
+ char *peer_str = kr_straddr(peer);
+ kr_log_verbose("[io] => closing connection to '%s'\n",
+ peer_str ? peer_str : "");
+ if (session_flags(s)->outgoing) {
+ worker_del_tcp_waiting(worker, peer);
+ worker_del_tcp_connected(worker, peer);
+ }
+ session_close(s);
+ }
+ }
+}
+
+static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
+{
+ struct session *s = handle->data;
+ assert(s && session_get_handle(s) == (uv_handle_t *)handle &&
+ handle->type == UV_TCP);
+
+ if (session_flags(s)->closing) {
+ return;
+ }
+
+ /* nread might be 0, which does not indicate an error or EOF.
+ * This is equivalent to EAGAIN or EWOULDBLOCK under read(2). */
+ if (nread == 0) {
+ return;
+ }
+
+ if (nread < 0 || !buf->base) {
+ if (kr_verbose_status) {
+ struct sockaddr *peer = session_get_peer(s);
+ char *peer_str = kr_straddr(peer);
+ kr_log_verbose("[io] => connection to '%s' closed by peer (%s)\n",
+ peer_str ? peer_str : "",
+ uv_strerror(nread));
+ }
+ worker_end_tcp(s);
+ return;
+ }
+
+ ssize_t consumed = 0;
+ const uint8_t *data = (const uint8_t *)buf->base;
+ ssize_t data_len = nread;
+ if (session_flags(s)->has_tls) {
+ /* buf->base points to the start of the TLS receive buffer.
+ Decode the data into the free space of the session wire buffer. */
+ consumed = tls_process_input_data(s, (const uint8_t *)buf->base, nread);
+ if (consumed < 0) {
+ if (kr_verbose_status) {
+ struct sockaddr *peer = session_get_peer(s);
+ char *peer_str = kr_straddr(peer);
+ kr_log_verbose("[io] => connection to '%s': "
+ "error processing TLS data, close\n",
+ peer_str ? peer_str : "");
+ }
+ worker_end_tcp(s);
+ return;
+ } else if (consumed == 0) {
+ return;
+ }
+ data = session_wirebuf_get_free_start(s);
+ data_len = consumed;
+ }
+
+ /* data points to the start of the free space in the session wire buffer.
+ Simply increase the internal counter. */
+ consumed = session_wirebuf_consume(s, data, data_len);
+ assert(consumed == data_len);
+
+ int ret = session_wirebuf_process(s, session_get_peer(s));
+ if (ret < 0) {
+ /* An error has occurred, close the session. */
+ worker_end_tcp(s);
+ }
+ session_wirebuf_compress(s);
+ struct worker_ctx *worker = handle->loop->data;
+ mp_flush(worker->pkt_pool.ctx);
+}
+
+static void _tcp_accept(uv_stream_t *master, int status, bool tls)
+{
+ if (status != 0) {
+ return;
+ }
+
+ struct worker_ctx *worker = the_worker;
+ uv_tcp_t *client = malloc(sizeof(uv_tcp_t));
+ if (!client) {
+ return;
+ }
+ int res = io_create(master->loop, (uv_handle_t *)client,
+ SOCK_STREAM, AF_UNSPEC, tls);
+ if (res) {
+ if (res == UV_EMFILE) {
+ worker->too_many_open = true;
+ worker->rconcurrent_highwatermark = worker->stats.rconcurrent;
+ }
+ /* Since res isn't OK, struct session wasn't allocated or borrowed.
+ * We must release only the client handle.
+ */
+ free(client);
+ return;
+ }
+
+ /* struct session was allocated or borrowed from the memory pool. */
+ struct session *s = client->data;
+ assert(session_flags(s)->outgoing == false);
+ assert(session_flags(s)->has_tls == tls);
+
+ if (uv_accept(master, (uv_stream_t *)client) != 0) {
+ /* close session, close underlying uv handles and
+ * deallocate (or return to memory pool) memory. */
+ session_close(s);
+ return;
+ }
+
+ /* Get peer's and our address. We apparently get specific sockname here
+ * even if we listened on a wildcard address. */
+ struct sockaddr *sa = session_get_peer(s);
+ int sa_len = sizeof(struct sockaddr_in6);
+ int ret = uv_tcp_getpeername(client, sa, &sa_len);
+ if (ret || sa->sa_family == AF_UNSPEC) {
+ session_close(s);
+ return;
+ }
+ sa = session_get_sockname(s);
+ sa_len = sizeof(struct sockaddr_in6);
+ ret = uv_tcp_getsockname(client, sa, &sa_len);
+ if (ret || sa->sa_family == AF_UNSPEC) {
+ session_close(s);
+ return;
+ }
+
+ /* Set deadlines for TCP connection and start reading.
+ * It will re-check every half of the request time limit whether the connection
+ * is idle and should be terminated; this is an educated guess. */
+
+ const struct network *net = &worker->engine->net;
+ uint64_t idle_in_timeout = net->tcp.in_idle_timeout;
+
+ uint64_t timeout = KR_CONN_RTT_MAX / 2;
+ if (tls) {
+ timeout += TLS_MAX_HANDSHAKE_TIME;
+ struct tls_ctx_t *ctx = session_tls_get_server_ctx(s);
+ if (!ctx) {
+ ctx = tls_new(worker);
+ if (!ctx) {
+ session_close(s);
+ return;
+ }
+ ctx->c.session = s;
+ ctx->c.handshake_state = TLS_HS_IN_PROGRESS;
+ session_tls_set_server_ctx(s, ctx);
+ }
+ }
+ session_timer_start(s, tcp_timeout_trigger, timeout, idle_in_timeout);
+ io_start_read((uv_handle_t *)client);
+}
+
+static void tcp_accept(uv_stream_t *master, int status)
+{
+ _tcp_accept(master, status, false);
+}
+
+static void tls_accept(uv_stream_t *master, int status)
+{
+ _tcp_accept(master, status, true);
+}
+
+int io_listen_tcp(uv_loop_t *loop, uv_tcp_t *handle, int fd, int tcp_backlog, bool has_tls)
+{
+ const uv_connection_cb connection = has_tls ? tls_accept : tcp_accept;
+ if (!handle) {
+ return kr_error(EINVAL);
+ }
+ int ret = uv_tcp_init(loop, handle);
+ if (ret) return ret;
+
+ ret = uv_tcp_open(handle, (uv_os_sock_t) fd);
+ if (ret) return ret;
+
+ int val; (void)val;
+ /* TCP_DEFER_ACCEPT delays accepting connections until there is readable data. */
+#ifdef TCP_DEFER_ACCEPT
+ val = KR_CONN_RTT_MAX/1000;
+ if (setsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &val, sizeof(val))) {
+ kr_log_error("[ io ] listen TCP (defer_accept): %s\n", strerror(errno));
+ }
+#endif
+
+ ret = uv_listen((uv_stream_t *)handle, tcp_backlog, connection);
+ if (ret != 0) {
+ return ret;
+ }
+
+ /* TCP_FASTOPEN enables 1 RTT connection resumptions. */
+#ifdef TCP_FASTOPEN
+ #ifdef __linux__
+ val = 16; /* Accepts queue length hint */
+ #else
+ val = 1; /* Accepts on/off */
+ #endif
+ if (setsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN, &val, sizeof(val))) {
+ kr_log_error("[ io ] listen TCP (fastopen): %s\n", strerror(errno));
+ }
+#endif
+
+ handle->data = NULL;
+ return 0;
+}
+
+int io_create(uv_loop_t *loop, uv_handle_t *handle, int type, unsigned family, bool has_tls)
+{
+ int ret = -1;
+ if (type == SOCK_DGRAM) {
+ ret = uv_udp_init(loop, (uv_udp_t *)handle);
+ } else if (type == SOCK_STREAM) {
+ ret = uv_tcp_init_ex(loop, (uv_tcp_t *)handle, family);
+ uv_tcp_nodelay((uv_tcp_t *)handle, 1);
+ }
+ if (ret != 0) {
+ return ret;
+ }
+ struct session *s = session_new(handle, has_tls);
+ if (s == NULL) {
+ ret = -1;
+ }
+ return ret;
+}
+
+void io_deinit(uv_handle_t *handle)
+{
+ if (!handle) {
+ return;
+ }
+ session_free(handle->data);
+ handle->data = NULL;
+}
+
+void io_free(uv_handle_t *handle)
+{
+ io_deinit(handle);
+ free(handle);
+}
+
+int io_start_read(uv_handle_t *handle)
+{
+ switch (handle->type) {
+ case UV_UDP:
+ return uv_udp_recv_start((uv_udp_t *)handle, &handle_getbuf, &udp_recv);
+ case UV_TCP:
+ return uv_read_start((uv_stream_t *)handle, &handle_getbuf, &tcp_recv);
+ default:
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+}
+
+int io_stop_read(uv_handle_t *handle)
+{
+ if (handle->type == UV_UDP) {
+ return uv_udp_recv_stop((uv_udp_t *)handle);
+ } else {
+ return uv_read_stop((uv_stream_t *)handle);
+ }
+}
--- /dev/null
+/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <uv.h>
+#include <libknot/packet/pkt.h>
+#include <gnutls/gnutls.h>
+#include "lib/generic/array.h"
+#include "worker.h"
+
+struct tls_ctx_t;
+struct tls_client_ctx_t;
+
+/** Bind an address to a file descriptor (no libuv involved). type is e.g. SOCK_DGRAM */
+int io_bind(const struct sockaddr *addr, int type, const endpoint_flags_t *flags);
+/** Initialize a UDP handle and start listening. */
+int io_listen_udp(uv_loop_t *loop, uv_udp_t *handle, int fd);
+/** Initialize a TCP handle and start listening. */
+int io_listen_tcp(uv_loop_t *loop, uv_tcp_t *handle, int fd, int tcp_backlog, bool has_tls);
+
+void tcp_timeout_trigger(uv_timer_t *timer);
+
+/** Initialize the handle, incl. ->data = struct session * instance.
+ * \param type = SOCK_*
+ * \param family = AF_*
+ * \param has_tls has meanings only when type is SOCK_STREAM */
+int io_create(uv_loop_t *loop, uv_handle_t *handle, int type,
+ unsigned family, bool has_tls);
+void io_deinit(uv_handle_t *handle);
+void io_free(uv_handle_t *handle);
+
+int io_start_read(uv_handle_t *handle);
+int io_stop_read(uv_handle_t *handle);
--- /dev/null
+-- Open cache if not set/disabled
+if not cache.current_size then
+ cache.size = 100 * MB
+end
+
--- /dev/null
+# kres-watcher: lua modules
+
+sandbox_watcher = configure_file(
+ input: 'sandbox-watcher.lua.in',
+ output: 'sandbox-watcher.lua',
+ configuration: lua_config,
+)
+
+lua_watcher_src = [
+ sandbox_watcher,
+ files('config-watcher.lua'),
+]
+
+# install watcher lua sources
+install_data(
+ lua_watcher_src,
+ install_dir: lib_dir,
+)
\ No newline at end of file
--- /dev/null
+local debug = require('debug')
+local ffi = require('ffi')
+
+-- Units
+kB = 1024
+MB = 1024*kB
+GB = 1024*MB
+-- Time
+sec = 1000
+second = sec
+minute = 60 * sec
+min = minute
+hour = 60 * minute
+day = 24 * hour
+
+-- Logging
+function panic(fmt, ...)
+ print(debug.traceback('error occurred here (config filename:lineno is '
+ .. 'at the bottom, if config is involved):', 2))
+ error(string.format('ERROR: '.. fmt, ...), 0)
+end
+function warn(fmt, ...)
+ io.stderr:write(string.format(fmt..'\n', ...))
+end
+function log(fmt, ...)
+ print(string.format(fmt, ...))
+end
+
+-- Resolver bindings
+kres = require('kres')
+if rawget(kres, 'str2dname') ~= nil then
+ todname = kres.str2dname
+end
+
+worker.resolve_pkt = function (pkt, options, finish, init)
+ options = kres.mk_qflags(options)
+ local task = ffi.C.worker_resolve_start(pkt, options)
+
+ -- Deal with finish and init callbacks
+ if finish ~= nil then
+ local finish_cb
+ finish_cb = ffi.cast('trace_callback_f',
+ function (req)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ finish(req.answer, req)
+ finish_cb:free()
+ end)
+ task.ctx.req.trace_finish = finish_cb
+ end
+ if init ~= nil then
+ init(task.ctx.req)
+ end
+
+ return ffi.C.worker_resolve_exec(task, pkt) == 0
+end
+
+worker.resolve = function (qname, qtype, qclass, options, finish, init)
+ -- Alternatively use named arguments
+ if type(qname) == 'table' then
+ local t = qname
+ qname = t.name
+ qtype = t.type
+ qclass = t.class
+ options = t.options
+ finish = t.finish
+ init = t.init
+ end
+ qtype = qtype or kres.type.A
+ qclass = qclass or kres.class.IN
+ options = kres.mk_qflags(options)
+ -- LATER: nicer errors for rubbish in qname, qtype, qclass?
+ local pkt = ffi.C.worker_resolve_mk_pkt(qname, qtype, qclass, options)
+ if pkt == nil then
+ panic('failure in worker.resolve(); probably invalid qname "%s"', qname)
+ end
+ local ret = worker.resolve_pkt(pkt, options, finish, init)
+ ffi.C.knot_rrset_free(pkt.opt_rr, nil);
+ ffi.C.knot_pkt_free(pkt);
+ return ret
+end
+resolve = worker.resolve
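+-- Illustrative usage only (domain and callback body are hypothetical):
+--   resolve('example.com', kres.type.AAAA, kres.class.IN, {},
+--     function (pkt, req) log('answer received') end)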
+
+-- Shorthand for aggregated per-worker information
+worker.info = function ()
+ local t = worker.stats()
+ t.pid = worker.pid
+ return t
+end
+
+-- Resolver mode of operation
+local current_mode = 'normal'
+local mode_table = { normal=0, strict=1, permissive=2 }
+function mode(m)
+ if not m then return current_mode end
+ if not mode_table[m] then error('unsupported mode: '..m) end
+ -- Update current operation mode
+ current_mode = m
+ option('STRICT', current_mode == 'strict')
+ option('PERMISSIVE', current_mode == 'permissive')
+ return true
+end
+
+-- Trivial option alias
+function reorder_RR(val)
+ return option('REORDER_RR', val)
+end
+
+-- Get/set resolver options via name (string)
+function option(name, val)
+ local flags = kres.context().options;
+ -- Note: no way to test the existence of flags[name], but we want an error anyway.
+ name = string.upper(name) -- convenience
+ if val ~= nil then
+ if (val ~= true) and (val ~= false) then
+ panic('invalid option value: ' .. tostring(val))
+ end
+ flags[name] = val;
+ end
+ return flags[name];
+end
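+-- For illustration: option('NO_IPV6') reads the flag, while
+-- option('NO_IPV6', true) sets it (available flag names depend on the resolver build).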
+
+-- Function aliases
+-- `env.VAR returns os.getenv(VAR)`
+env = {}
+setmetatable(env, {
+ __index = function (_, k) return os.getenv(k) end
+})
+
+-- Quick access to interfaces
+-- `net.<iface>` => `net.interfaces()[iface]`
+-- `net = {addr1, ..}` => `net.listen(name, addr1)`
+-- `net.ipv{4,6} = {true, false}` => enable/disable IPv{4,6}
+setmetatable(net, {
+ __index = function (t, k)
+ local v = rawget(t, k)
+ if v then return v
+ elseif k == 'ipv6' then return not option('NO_IPV6')
+ elseif k == 'ipv4' then return not option('NO_IPV4')
+ else return net.interfaces()[k]
+ end
+ end,
+ __newindex = function (t,k,v)
+ if k == 'ipv6' then return option('NO_IPV6', not v)
+ elseif k == 'ipv4' then return option('NO_IPV4', not v)
+ else
+ local iname = rawget(net.interfaces(), v)
+ if iname then t.listen(iname)
+ else t.listen(v)
+ end
+ end
+ end
+})
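+-- Illustrative examples of the sugar above (interface name and address are hypothetical):
+--   net.eth0               -- roughly net.interfaces()['eth0']
+--   net = { '127.0.0.1' }  -- roughly net.listen('127.0.0.1')
+--   net.ipv6 = false       -- roughly option('NO_IPV6', true)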
+
+-- Syntactic sugar for module loading
+-- `modules.<name> = <config>`
+setmetatable(modules, {
+ __newindex = function (_, k, v)
+ if type(k) == 'number' then
+ k, v = v, nil
+ end
+ if not rawget(_G, k) then
+ modules.load(k)
+ k = string.match(k, '[%w_]+')
+ local mod = _G[k]
+ local config = mod and rawget(mod, 'config')
+ if mod ~= nil and config ~= nil then
+ if k ~= v then config(v)
+ else config()
+ end
+ end
+ end
+ end
+})
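+-- Illustrative example (module name hypothetical): the sugar above makes
+--   modules.hints = '/etc/hosts'
+-- roughly equivalent to
+--   modules.load('hints'); hints.config('/etc/hosts')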
+
+-- Set up lua table for a C module. (Internal function.)
+function modules_create_table_for_c(kr_module_ud)
+ local kr_module = ffi.cast('struct kr_module **', kr_module_ud)[0]
+ --- Set up the global table named according to the module.
+ if kr_module.config == nil and kr_module.props == nil then
+ return
+ end
+ local module = {}
+ local module_name = ffi.string(kr_module.name)
+ _G[module_name] = module
+
+ --- Construct lua functions for properties.
+ if kr_module.props ~= nil then
+ local i = 0
+ while true do
+ local prop = kr_module.props[i]
+ local cb = prop.cb
+ if cb == nil then break; end
+ module[ffi.string(prop.name)] =
+ function (arg) -- lua wrapper around kr_prop_cb function typedef
+ local arg_conv
+ if type(arg) == 'table' or type(arg) == 'boolean' then
+ arg_conv = tojson(arg)
+ elseif arg ~= nil then
+ arg_conv = tostring(arg)
+ end
+ local ret_cstr = cb(__engine, kr_module, arg_conv)
+ if ret_cstr == nil then
+ return nil
+ end
+ -- LATER(optim.): superfluous copying
+ local ret_str = ffi.string(ret_cstr)
+ -- This is a bit ugly, but the API is that invalid JSON
+ -- should be just returned as string :-(
+ local status, ret = pcall(fromjson, ret_str)
+ if not status then ret = ret_str end
+ ffi.C.free(ret_cstr)
+ return ret
+ end
+ i = i + 1
+ end
+ end
+
+ --- Construct lua function for config().
+ if kr_module.config ~= nil then
+ module.config =
+ function (arg)
+ local arg_conv
+ if type(arg) == 'table' or type(arg) == 'boolean' then
+ arg_conv = tojson(arg)
+ elseif arg ~= nil then
+ arg_conv = tostring(arg)
+ end
+ return kr_module.config(kr_module, arg_conv)
+ end
+ end
+
+ --- Add syntactic sugar for get() and set() properties.
+ --- That also "catches" any commands like `moduleName.foo = bar`.
+ local m_index, m_newindex
+ local get_f = rawget(module, 'get')
+ if get_f ~= nil then
+ m_index = function (_, key)
+ return get_f(key)
+ end
+ else
+ m_index = function ()
+ error('module ' .. module_name .. ' does not support indexing syntax sugar')
+ end
+ end
+ local set_f = rawget(module, 'set')
+ if set_f ~= nil then
+ m_newindex = function (_, key, value)
+ -- This will produce a nasty error on some non-string parameters.
+ -- Still, we already use it with integer values, e.g. in predict module :-/
+ return set_f(key .. ' ' .. value)
+ end
+ else
+ m_newindex = function ()
+ error('module ' .. module_name .. ' does not support assignment syntax sugar')
+ end
+ end
+ setmetatable(module, {
+ -- note: the two functions only get called for *missing* indices
+ __index = m_index,
+ __newindex = m_newindex,
+ })
+end
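+-- Illustrative: for a C module named 'stats' (hypothetical here) that exposes a
+-- 'get' property, the sugar above makes stats['answer.total'] behave roughly like
+-- stats.get('answer.total'); assignments go through 'set' analogously.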
+
+local layer_ctx = ffi.C.kr_layer_t_static
+-- Utilities internal for lua layer glue; see ../ffimodule.c
+modules_ffi_layer_wrap1 = function (layer_cb)
+ return layer_cb(layer_ctx.state, layer_ctx.req)
+end
+modules_ffi_layer_wrap2 = function (layer_cb)
+ return layer_cb(layer_ctx.state, layer_ctx.req, layer_ctx.pkt)
+end
+modules_ffi_layer_wrap_checkout = function (layer_cb)
+ return layer_cb(layer_ctx.state, layer_ctx.req, layer_ctx.pkt,
+ layer_ctx.dst, layer_ctx.is_stream)
+end
+modules_ffi_wrap_modcb = function (cb, kr_module_ud) -- this one isn't for layer
+ local kr_module = ffi.cast('struct kr_module **', kr_module_ud)[0]
+ return cb(kr_module)
+end
+
+cache.clear = function (name, exact_name, rr_type, chunk_size, callback, prev_state)
+ if name == nil or (name == '.' and not exact_name) then
+ -- keep same output format as for 'standard' clear
+ local total_count = cache.count()
+ if not cache.clear_everything() then
+ error('unable to clear everything')
+ end
+ return {count = total_count}
+ end
+ -- Check parameters, in order, and set defaults if missing.
+ local dname = kres.str2dname(name)
+ if not dname then error('cache.clear(): incorrect name passed') end
+ if exact_name == nil then exact_name = false end
+ if type(exact_name) ~= 'boolean'
+ then error('cache.clear(): incorrect exact_name passed') end
+
+ local cach = kres.context().cache;
+ local rettable = {}
+ -- Apex warning. If the caller passes a custom callback,
+ -- we assume they are advanced enough not to need the check.
+ -- The point is to avoid repeating the check in each callback iteration.
+ if callback == nil then
+ local apex_array = ffi.new('knot_dname_t *[1]') -- C: dname **apex_array
+ local ret = ffi.C.kr_cache_closest_apex(cach, dname, false, apex_array)
+ if ret < 0 then
+ error(ffi.string(ffi.C.knot_strerror(ret))) end
+ if not ffi.C.knot_dname_is_equal(apex_array[0], dname) then
+ local apex_str = kres.dname2str(apex_array[0])
+ rettable.not_apex = 'to clear proofs of non-existence call '
+ .. 'cache.clear(\'' .. tostring(apex_str) ..'\')'
+ rettable.subtree = apex_str
+ end
+ ffi.C.free(apex_array[0])
+ end
+
+ if rr_type ~= nil then
+ -- Special case, without any subtree searching.
+ if not exact_name
+ then error('cache.clear(): specifying rr_type only supported with exact_name') end
+ if chunk_size or callback
+ then error('cache.clear(): chunk_size and callback parameters not supported with rr_type') end
+ local ret = ffi.C.kr_cache_remove(cach, dname, rr_type)
+ if ret < 0 then error(ffi.string(ffi.C.knot_strerror(ret))) end
+ return {count = 1}
+ end
+
+ if chunk_size == nil then chunk_size = 100 end
+ if type(chunk_size) ~= 'number' or chunk_size <= 0
+ then error('cache.clear(): chunk_size has to be a positive integer') end
+
+ -- Do the C call, and add chunk_size warning.
+ rettable.count = ffi.C.kr_cache_remove_subtree(cach, dname, exact_name, chunk_size)
+ if rettable.count == chunk_size then
+ local msg_extra = ''
+ if callback == nil then
+ msg_extra = '; the default callback will continue asynchronously'
+ end
+ rettable.chunk_limit = 'chunk size limit reached' .. msg_extra
+ end
+
+ -- Default callback function: repeat after 1ms
+ if callback == nil then callback =
+ function (cbname, cbexact_name, cbrr_type, cbchunk_size, cbself, cbprev_state, cbrettable)
+ if cbrettable.count < 0 then error(ffi.string(ffi.C.knot_strerror(cbrettable.count))) end
+ if cbprev_state == nil then cbprev_state = { round = 0 } end
+ if type(cbprev_state) ~= 'table'
+ then error('cache.clear() callback: incorrect prev_state passed') end
+ cbrettable.round = cbprev_state.round + 1
+ if (cbrettable.count == cbchunk_size) then
+ event.after(1, function ()
+ cache.clear(cbname, cbexact_name, cbrr_type, cbchunk_size, cbself, cbrettable)
+ end)
+ elseif cbrettable.round > 1 then
+ log('[cache] asynchronous cache.clear(\'' .. cbname .. '\', '
+ .. tostring(cbexact_name) .. ') finished')
+ end
+ return cbrettable
+ end
+ end
+ return callback(name, exact_name, rr_type, chunk_size, callback, prev_state, rettable)
+end
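+-- Illustrative calls (domain is hypothetical):
+--   cache.clear()               -- clear everything
+--   cache.clear('example.com.') -- clear the subtree, 100 records per chunk,
+--                               -- continuing asynchronously via the default callback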
+-- Syntactic sugar for cache
+-- `cache[x] -> cache.get(x)`
+-- `cache.{size|storage} = value`
+setmetatable(cache, {
+ __index = function (t, k)
+ local res = rawget(t, k)
+ if not res and not rawget(t, 'current_size') then return res end
+ -- Beware: t.get returns empty table on failure to find.
+ -- That would be confusing here (breaking kresc), so return nil instead.
+ res = t.get(k)
+ if res and next(res) ~= nil then return res else return nil end
+ end,
+ __newindex = function (t,k,v)
+ -- Defaults
+ local storage = rawget(t, 'current_storage')
+ if not storage then storage = 'lmdb://' end
+ local size = rawget(t, 'current_size')
+ if not size then size = 10*MB end
+ -- Declarative interface for cache
+ if k == 'size' then t.open(v, storage)
+ elseif k == 'storage' then t.open(size, v) end
+ end
+})
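+-- Illustrative examples of the sugar above (values are arbitrary):
+--   cache.size = 200*MB       -- reopen the cache with the new size
+--   cache.storage = 'lmdb://' -- reopen with the given storage backend
+--   cache['example.com.']     -- roughly cache.get('example.com.'), nil if not found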
+
+-- Make sandboxed environment
+local function make_sandbox(defined)
+ local __protected = { worker = true, env = true, modules = true, cache = true, net = true, trust_anchors = true }
+
+ -- Compute and export the list of top-level names (hidden otherwise)
+ local nl = ""
+ for n in pairs(defined) do
+ nl = nl .. n .. "\n"
+ end
+
+ return setmetatable({ __orig_name_list = nl }, {
+ __index = defined,
+ __newindex = function (_, k, v)
+ if __protected[k] then
+ for k2,v2 in pairs(v) do
+ defined[k][k2] = v2
+ end
+ else
+ defined[k] = v
+ end
+ end
+ })
+end
+
+-- Compatibility sandbox
+_G = make_sandbox(getfenv(0))
+setfenv(0, _G)
+
+-- Load default modules
+--trust_anchors = require('trust_anchors')
+--modules.load('ta_update')
+--modules.load('ta_signal_query')
+--modules.load('policy')
+--modules.load('priming')
+--modules.load('detect_time_skew')
+--modules.load('detect_time_jump')
+--modules.load('ta_sentinel')
+--modules.load('edns_keepalive')
+--modules.load('refuse_nord')
+--modules.load('watchdog')
+
+modules.load('prefill')
+prefill.config({
+ ['.'] = {
+ url = 'https://www.internic.net/domain/root.zone',
+ interval = 86400,
+ ca_file = '/etc/pki/tls/certs/ca-bundle.crt',
+ }
+})
+
+-- Load keyfile_default
+-- trust_anchors.add_file('@keyfile_default@', @unmanaged@)
+
+-- Interactive command evaluation
+function eval_cmd(line, raw)
+ -- Compatibility sandbox code loading
+ local function load_code(code)
+ if getfenv then -- Lua 5.1
+ return loadstring(code)
+ else -- Lua 5.2+
+ return load(code, nil, 't', _ENV)
+ end
+ end
+ local err, chunk
+ chunk, err = load_code(raw and 'return '..line or 'return table_print('..line..')')
+ if err then
+ chunk, err = load_code(line)
+ end
+ if not err then
+ return chunk()
+ else
+ error(err)
+ end
+end
+
+-- Pretty printing
+
+local function funcsign(f)
+-- thanks to AnandA777 from StackOverflow! Function funcsign is an adapted version of
+-- https://stackoverflow.com/questions/51095022/inspect-function-signature-in-lua-5-1
+ assert(type(f) == 'function', "bad argument #1 to 'funcsign' (function expected)")
+ local debuginfo = debug.getinfo(f)
+ if debuginfo.what == 'C' then -- names N/A
+ return '(?)'
+ end
+
+ local func_args = {}
+ pcall(function()
+ local oldhook
+ local delay = 2
+ local function hook()
+ delay = delay - 1
+ if delay == 0 then -- call this only for the introspected function
+ -- stack depth 2 is the introspected function
+ for i = 1, debuginfo.nparams do
+ local k = debug.getlocal(2, i)
+ table.insert(func_args, k)
+ end
+ if debuginfo.isvararg then
+ table.insert(func_args, "...")
+ end
+ debug.sethook(oldhook)
+ error('aborting the call to introspected function')
+ end
+ end
+ oldhook = debug.sethook(hook, "c") -- invoke hook() on function call
+ f(unpack({})) -- huh?
+ end)
+ return "(" .. table.concat(func_args, ", ") .. ")"
+end
+
+function table_print (tt, indent, done)
+ done = done or {}
+ indent = indent or 0
+ local result = ""
+ -- Ordered for-iterator for tables with tostring-able keys.
+ local function ordered_iter(unordered_tt)
+ local keys = {}
+ for k in pairs(unordered_tt) do
+ table.insert(keys, k)
+ end
+ table.sort(keys, function (a, b) return tostring(a) < tostring(b) end)
+ local i = 0
+ return function()
+ i = i + 1
+ if keys[i] then
+ return keys[i], unordered_tt[keys[i]]
+ end
+ end
+ end
+ -- Convert to printable string (escape unprintable)
+ local function printable(value)
+ value = tostring(value)
+ local bytes = {}
+ for i = 1, #value do
+ local c = string.byte(value, i)
+ if c >= 0x20 and c < 0x7f then table.insert(bytes, string.char(c))
+ else table.insert(bytes, '\\'..tostring(c))
+ end
+ if i > 80 then table.insert(bytes, '...') break end
+ end
+ return table.concat(bytes)
+ end
+ if type(tt) == "table" then
+ for key, value in ordered_iter(tt) do
+ result = result .. string.rep (" ", indent)
+ if type (value) == "table" and not done [value] then
+ done [value] = true
+ result = result .. string.format("[%s] => {\n", printable (key))
+ result = result .. table_print (value, indent + 4, done)
+ result = result .. string.rep (" ", indent)
+ result = result .. "}\n"
+ elseif type (value) == "function" then
+ result = result .. string.format("[%s] => function %s%s: %s\n",
+ tostring(key), tostring(key), funcsign(value),
+ string.sub(tostring(value), 11))
+ else
+ result = result .. string.format("[%s] => %s\n",
+ tostring (key), printable(value))
+ end
+ end
+ else -- not a table
+ local tt_str
+ if type(tt) == "function" then
+ tt_str = string.format("function%s: %s\n", funcsign(tt),
+ string.sub(tostring(tt), 11))
+ else
+ tt_str = tostring(tt)
+ end
+ result = result .. tt_str .. "\n"
+ end
+ return result
+end
+
+-- This extends the worker module to allow asynchronous execution of functions and nonblocking I/O.
+-- The current implementation combines cqueues for the Lua interface with event.socket(),
+-- so that the resolver engine is not blocked while waiting for I/O or timers.
+--
+local has_cqueues, cqueues = pcall(require, 'cqueues')
+if has_cqueues then
+
+ -- Export the asynchronous sleep function
+ worker.sleep = cqueues.sleep
+
+ -- Create metatable for workers to define the API
+ -- It can schedule multiple cqueues and yield execution when there's a wait for blocking I/O or timer
+ local asynchronous_worker_mt = {
+ work = function (self)
+ local ok, err, _, co = self.cq:step(0)
+ if not ok then
+ warn('[%s] error: %s %s', self.name or 'worker', err, debug.traceback(co))
+ end
+ -- Reschedule timeout or create new one
+ local timeout = self.cq:timeout()
+ if timeout then
+ -- Throttle timeouts to avoid too frequent wakeups
+ if timeout == 0 then timeout = 0.00001 end
+ -- Convert from seconds to duration
+ timeout = timeout * sec
+ if not self.next_timeout then
+ self.next_timeout = event.after(timeout, self.on_step)
+ else
+ event.reschedule(self.next_timeout, timeout)
+ end
+ else -- Cancel running timeout when there is no next deadline
+ if self.next_timeout then
+ event.cancel(self.next_timeout)
+ self.next_timeout = nil
+ end
+ end
+ end,
+ wrap = function (self, f)
+ self.cq:wrap(f)
+ end,
+ loop = function (self)
+ self.on_step = function () self:work() end
+ self.event_fd = event.socket(self.cq:pollfd(), self.on_step)
+ end,
+ close = function (self)
+ if self.event_fd then
+ event.cancel(self.event_fd)
+ self.event_fd = nil
+ end
+ end,
+ }
+
+ -- Implement the coroutine worker with cqueues
+ local function worker_new (name)
+ return setmetatable({name = name, cq = cqueues.new()}, { __index = asynchronous_worker_mt })
+ end
+
+ -- Create a default background worker
+ worker.bg_worker = worker_new('worker.background')
+ worker.bg_worker:loop()
+
+ -- Wrap a function for asynchronous execution
+ function worker.coroutine (f)
+ worker.bg_worker:wrap(f)
+ end
+else
+ -- Disable asynchronous execution
+ local function disabled () error('cqueues are required for asynchronous execution') end
+ worker.sleep = disabled
+ worker.map = disabled
+ worker.coroutine = disabled
+end
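+
+-- Usage sketch (requires cqueues): code wrapped via worker.coroutine() runs
+-- inside a cqueues coroutine, so worker.sleep() yields to the event loop
+-- instead of blocking the resolver engine, e.g.:
+--   worker.coroutine(function ()
+--       for i = 1, 3 do
+--           print('tick ' .. i)
+--           worker.sleep(1)
+--       end
+--   end)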
--- /dev/null
+/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "kresconfig.h"
+
+#include "contrib/ccan/asprintf/asprintf.h"
+#include "contrib/cleanup.h"
+#include "contrib/ucw/mempool.h"
+#include "engine.h"
+#include "io.h"
+#include "network.h"
+#include "tls.h"
+#include "udp_queue.h"
+#include "worker.h"
+#include "lib/defines.h"
+#include "lib/dnssec.h"
+#include "lib/dnssec/ta.h"
+#include "lib/resolve.h"
+
+#include <arpa/inet.h>
+#include <assert.h>
+#include <getopt.h>
+#include <libgen.h>
+#include <signal.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/resource.h>
+#include <unistd.h>
+
+#include "watcher.h"
+#include "modules/sysrepo/common/sysrepo.h"
+#include "sr_subscriptions.h"
+
+#ifdef ENABLE_CAP_NG
+#include <cap-ng.h>
+#endif
+
+#include <lua.h>
+#include <uv.h>
+#if SYSTEMD_VERSION > 0
+#include <systemd/sd-daemon.h>
+#endif
+#include <libknot/error.h>
+
+
+/* @internal Array of IP address strings (shorthand form). */
+typedef array_t(char*) addr_array_t;
+
+typedef array_t(const char*) config_array_t;
+
+typedef struct {
+ int fd;
+ endpoint_flags_t flags; /**< .sock_type isn't meaningful here */
+} flagged_fd_t;
+typedef array_t(flagged_fd_t) flagged_fd_array_t;
+
+struct args {
+ addr_array_t addrs, addrs_tls;
+ flagged_fd_array_t fds;
+ int control_fd;
+ int forks;
+ config_array_t config;
+ const char *rundir;
+ bool interactive;
+ bool quiet;
+ bool tty_binary_output;
+};
+
+/**
+ * TTY control: process input and free() the buffer.
+ *
+ * For parameters see http://docs.libuv.org/en/v1.x/stream.html#c.uv_read_cb
+ *
+ * - This is just basic read-eval-print; libedit is supported through kresc;
+ * - stream->data contains program arguments (struct args);
+ */
+static void tty_process_input(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf)
+{
+ char *cmd = buf ? buf->base : NULL; /* To be free()d on return. */
+
+ /* Set output streams */
+ FILE *out = stdout;
+ uv_os_fd_t stream_fd = 0;
+ struct args *args = stream->data;
+ if (uv_fileno((uv_handle_t *)stream, &stream_fd)) {
+ uv_close((uv_handle_t *)stream, (uv_close_cb) free);
+ free(cmd);
+ return;
+ }
+ if (stream_fd != STDIN_FILENO) {
+ if (nread < 0) { /* Close if disconnected */
+ uv_close((uv_handle_t *)stream, (uv_close_cb) free);
+ }
+ if (nread <= 0) {
+ free(cmd);
+ return;
+ }
+ uv_os_fd_t dup_fd = dup(stream_fd);
+ if (dup_fd >= 0) {
+ out = fdopen(dup_fd, "w");
+ }
+ }
+
+ /* Execute */
+ if (stream && cmd && nread > 0) {
+ /* Ensure cmd is 0-terminated */
+ if (cmd[nread - 1] == '\n') {
+ cmd[nread - 1] = '\0';
+ } else {
+ if (nread >= buf->len) { /* only equality should be possible */
+ char *newbuf = realloc(cmd, nread + 1);
+ if (!newbuf)
+ goto finish;
+ cmd = newbuf;
+ }
+ cmd[nread] = '\0';
+ }
+
+ /* Pseudo-command for switching to "binary output"; */
+ if (strcmp(cmd, "__binary") == 0) {
+ args->tty_binary_output = true;
+ goto finish;
+ }
+
+ lua_State *L = the_worker->engine->L;
+ int ret = engine_cmd(L, cmd, false);
+ const char *message = "";
+ if (lua_gettop(L) > 0) {
+ message = lua_tostring(L, -1);
+ }
+
+ /* Simpler output in binary mode */
+ if (args->tty_binary_output) {
+ size_t len_s = strlen(message);
+ if (len_s > UINT32_MAX)
+ goto finish;
+ uint32_t len_n = htonl(len_s);
+ fwrite(&len_n, sizeof(len_n), 1, out);
+ fwrite(message, len_s, 1, out);
+ lua_settop(L, 0);
+ goto finish;
+ }
+
+ /* Log to remote socket if connected */
+ const char *delim = args->quiet ? "" : "> ";
+ if (stream_fd != STDIN_FILENO) {
+ if (VERBOSE_STATUS)
+ fprintf(stdout, "%s\n", cmd); /* Duplicate command to logs */
+ if (message)
+ fprintf(out, "%s", message); /* Duplicate output to sender */
+ if (message || !args->quiet)
+ fprintf(out, "\n");
+ fprintf(out, "%s", delim);
+ }
+ if (stream_fd == STDIN_FILENO || VERBOSE_STATUS) {
+ /* Log to standard streams */
+ FILE *fp_out = ret ? stderr : stdout;
+ if (message)
+ fprintf(fp_out, "%s", message);
+ if (message || !args->quiet)
+ fprintf(fp_out, "\n");
+ fprintf(fp_out, "%s", delim);
+ }
+ lua_settop(L, 0);
+ }
+finish:
+ free(cmd);
+ /* Close if redirected */
+ if (stream_fd != STDIN_FILENO) {
+ fclose(out);
+ }
+}
+
+static void tty_alloc(uv_handle_t *handle, size_t suggested, uv_buf_t *buf) {
+ buf->len = suggested;
+ buf->base = malloc(suggested);
+}
+
+static void tty_accept(uv_stream_t *master, int status)
+{
+ uv_tcp_t *client = malloc(sizeof(*client));
+ struct args *args = master->data;
+ if (client) {
+ uv_tcp_init(master->loop, client);
+ if (uv_accept(master, (uv_stream_t *)client) != 0) {
+ free(client);
+ return;
+ }
+ client->data = args;
+ uv_read_start((uv_stream_t *)client, tty_alloc, tty_process_input);
+ /* Write command line */
+ if (!args->quiet) {
+ uv_buf_t buf = { "> ", 2 };
+ uv_try_write((uv_stream_t *)client, &buf, 1);
+ }
+ }
+}
+
+/* @internal AF_LOCAL reads may still be interrupted, loop it. */
+static bool ipc_readall(int fd, char *dst, size_t len)
+{
+ while (len > 0) {
+ int rb = read(fd, dst, len);
+ if (rb > 0) {
+ dst += rb;
+ len -= rb;
+ } else if (errno != EAGAIN && errno != EINTR) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static void ipc_activity(uv_poll_t *handle, int status, int events)
+{
+ struct engine *engine = handle->data;
+ if (status != 0) {
+ kr_log_error("[system] ipc: %s\n", uv_strerror(status));
+ return;
+ }
+ /* Get file descriptor from handle */
+ uv_os_fd_t fd = 0;
+ (void) uv_fileno((uv_handle_t *)(handle), &fd);
+ /* Read expression from IPC pipe */
+ uint32_t len = 0;
+ auto_free char *rbuf = NULL;
+ if (!ipc_readall(fd, (char *)&len, sizeof(len))) {
+ goto failure;
+ }
+ if (len < UINT32_MAX) {
+ rbuf = malloc(len + 1);
+ } else {
+ errno = EINVAL;
+ }
+ if (!rbuf) {
+ goto failure;
+ }
+ if (!ipc_readall(fd, rbuf, len)) {
+ goto failure;
+ }
+ rbuf[len] = '\0';
+ /* Run expression */
+ const char *message = "";
+ int ret = engine_ipc(engine, rbuf);
+ if (ret > 0) {
+ message = lua_tostring(engine->L, -1);
+ }
+ /* Clear the Lua stack */
+ lua_settop(engine->L, 0);
+ /* Send response back */
+ len = strlen(message);
+ if (write(fd, &len, sizeof(len)) != sizeof(len) ||
+ write(fd, message, len) != len) {
+ goto failure;
+ }
+ return; /* success! */
+failure:
+ /* Note that if the piped command got read or written partially,
+ * we would get out of sync and only receive rubbish now.
+ * Therefore we prefer to stop IPC, but we try to continue with all else.
+ */
+ kr_log_error("[system] stopping ipc because of: %s\n", strerror(errno));
+ uv_poll_stop(handle);
+ uv_close((uv_handle_t *)handle, (uv_close_cb)free);
+}
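+
+/* Illustrative sketch (not part of the daemon): the peer end of the IPC
+ * socketpair frames a Lua expression with a raw uint32_t length prefix and
+ * reads the reply framed the same way, matching ipc_activity() above:
+ *
+ *     const char *expr = "quit()";
+ *     uint32_t len = strlen(expr);
+ *     write(fd, &len, sizeof(len));
+ *     write(fd, expr, len);
+ *
+ * The length stays in host byte order, as both ends are forks of one process. */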
+
+static bool ipc_watch(uv_loop_t *loop, struct engine *engine, int fd)
+{
+ uv_poll_t *poller = malloc(sizeof(*poller));
+ if (!poller) {
+ return false;
+ }
+ int ret = uv_poll_init(loop, poller, fd);
+ if (ret != 0) {
+ free(poller);
+ return false;
+ }
+ poller->data = engine;
+ ret = uv_poll_start(poller, UV_READABLE, ipc_activity);
+ if (ret != 0) {
+ free(poller);
+ return false;
+ }
+ /* libuv sets O_NONBLOCK whether we want it or not */
+	(void) fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) & ~O_NONBLOCK);
+ return true;
+}
+
+static void signal_handler(uv_signal_t *handle, int signum)
+{
+ uv_stop(uv_default_loop());
+ uv_signal_stop(handle);
+}
+
+/** SIGBUS -> attempt to remove the overflowing cache file and abort. */
+static void sigbus_handler(int sig, siginfo_t *siginfo, void *ptr)
+{
+ /* We can't safely assume that printf-like functions work, but write() is OK.
+ * See POSIX for the safe functions, e.g. 2017 version just above this link:
+ * http://pubs.opengroup.org/onlinepubs/9699919799/functions/V2_chap02.html#tag_15_04_04
+ */
+ #define WRITE_ERR(err_charray) \
+ (void)write(STDERR_FILENO, err_charray, sizeof(err_charray))
+ /* Unfortunately, void-cast on the write isn't enough to avoid the warning. */
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wunused-result"
+ const char msg_typical[] =
+ "\nSIGBUS received; this is most likely due to filling up the filesystem where cache resides.\n",
+ msg_unknown[] = "\nSIGBUS received, cause unknown.\n",
+ msg_deleted[] = "Cache file deleted.\n",
+ msg_del_fail[] = "Cache file deletion failed.\n",
+ msg_final[] = "kresd can not recover reliably by itself, exiting.\n";
+ if (siginfo->si_code != BUS_ADRERR) {
+ WRITE_ERR(msg_unknown);
+ goto end;
+ }
+ WRITE_ERR(msg_typical);
+ if (!kr_cache_emergency_file_to_remove) goto end;
+ if (unlink(kr_cache_emergency_file_to_remove)) {
+ WRITE_ERR(msg_del_fail);
+ } else {
+ WRITE_ERR(msg_deleted);
+ }
+end:
+ WRITE_ERR(msg_final);
+ _exit(128 - sig); /*< regular return from OS-raised SIGBUS can't work anyway */
+ #undef WRITE_ERR
+ #pragma GCC diagnostic pop
+}
+
+
+/*
+ * Server operation.
+ */
+
+static int fork_workers(fd_array_t *ipc_set, int forks)
+{
+ /* Fork subprocesses if requested */
+ while (--forks > 0) {
+ int sv[2] = {-1, -1};
+ if (socketpair(AF_LOCAL, SOCK_STREAM, 0, sv) < 0) {
+ perror("[system] socketpair");
+ return kr_error(errno);
+ }
+ int pid = fork();
+ if (pid < 0) {
+ perror("[system] fork");
+ return kr_error(errno);
+ }
+
+ /* Forked process */
+ if (pid == 0) {
+ array_clear(*ipc_set);
+ array_push(*ipc_set, sv[0]);
+ close(sv[1]);
+ return forks;
+ /* Parent process */
+ } else {
+ array_push(*ipc_set, sv[1]);
+ /* Do not share parent-end with other forks. */
+ (void) fcntl(sv[1], F_SETFD, FD_CLOEXEC);
+ close(sv[0]);
+ }
+ }
+ return 0;
+}
+
+static void help(int argc, char *argv[])
+{
+ printf("Usage: %s [parameters] [rundir]\n", argv[0]);
+ printf("\nParameters:\n"
+ " -c, --config=[path] Config file path (relative to [rundir]) (default: config).\n"
+ " -q, --quiet No command prompt in interactive mode.\n"
+ " -v, --verbose Run in verbose mode."
+#ifdef NOVERBOSELOG
+ " (Recompile without -DNOVERBOSELOG to activate.)"
+#endif
+ "\n"
+ " -V, --version Print version of the server.\n"
+ " -h, --help Print help and usage.\n"
+ "Options:\n"
+ " [rundir] Path to the working directory (default: .)\n");
+}
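+
+/* Invocation sketch (paths hypothetical): an interactive run from a rundir
+ * containing a 'config' file could be
+ *     ./kres-watcher -v /tmp/rundir
+ * while a supervised, non-interactive single process could be started as
+ *     ./kres-watcher -c config.lua -f 1 -q
+ */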
+
+/** \return exit code for main() */
+static int run_worker(uv_loop_t *loop, struct engine *engine, fd_array_t *ipc_set, bool leader, struct args *args)
+{
+ /* Start sysrepo context */
+ sysrepo_uv_ctx_t *sysrepo = engine->watcher.sysrepo;
+ int ret = 0;
+ if (!ret) ret = sysrepo_subscr_register(sysrepo->session, &sysrepo->subscription);
+ if (!ret) ret = sysrepo_ctx_start(loop, sysrepo);
+ if (ret) {
+ kr_log_error("[sysrepo] failed to subscribe for changes: %s\n", sr_strerror(ret));
+ }
+
+ /* Start Processes if Enabled */
+ //resolver_start();
+
+ /* Only some kinds of stdin work with uv_pipe_t.
+ * Otherwise we would abort() from libuv e.g. with </dev/null */
+ if (args->interactive) switch (uv_guess_handle(0)) {
+ case UV_TTY: /* standard terminal */
+ /* TODO: it has worked OK so far, but we'd better use uv_tty_*
+ * for this case instead of uv_pipe_*. */
+ case UV_NAMED_PIPE: /* echo 'quit()' | kresd ... */
+ break;
+ default:
+ kr_log_error(
+ "[system] error: standard input is not a terminal or pipe; "
+ "use '-f 1' if you want non-interactive mode. "
+ "Commands can be simply added to your configuration file or sent over the tty/$PID control socket.\n"
+ );
+ return EXIT_FAILURE;
+ }
+
+ if (setvbuf(stdout, NULL, _IONBF, 0) || setvbuf(stderr, NULL, _IONBF, 0)) {
+		kr_log_error("[system] failed to set output buffering (ignored): %s\n",
+ strerror(errno));
+ fflush(stderr);
+ }
+
+ /* Control sockets or TTY */
+ auto_free char *sock_file = NULL;
+ uv_pipe_t pipe;
+ uv_pipe_init(loop, &pipe, 0);
+ pipe.data = args;
+ if (args->interactive) {
+ if (!args->quiet)
+ printf("[system] interactive mode\n> ");
+ uv_pipe_open(&pipe, 0);
+ uv_read_start((uv_stream_t*) &pipe, tty_alloc, tty_process_input);
+ } else {
+ int pipe_ret = -1;
+ if (args->control_fd != -1) {
+ pipe_ret = uv_pipe_open(&pipe, args->control_fd);
+ } else {
+ (void) mkdir("tty", S_IRWXU|S_IRWXG);
+ sock_file = afmt("tty/%ld", (long)getpid());
+ if (sock_file) {
+ pipe_ret = uv_pipe_bind(&pipe, sock_file);
+ }
+ }
+ if (!pipe_ret)
+ uv_listen((uv_stream_t *) &pipe, 16, tty_accept);
+ }
+ /* Watch IPC pipes (or just assign them if leading the pgroup). */
+ if (!leader) {
+ for (size_t i = 0; i < ipc_set->len; ++i) {
+ if (!ipc_watch(loop, engine, ipc_set->at[i])) {
+ kr_log_error("[system] failed to create poller: %s\n", strerror(errno));
+ close(ipc_set->at[i]);
+ }
+ }
+ }
+ memcpy(&engine->ipc_set, ipc_set, sizeof(*ipc_set));
+
+ /* Notify supervisor. */
+#if SYSTEMD_VERSION > 0
+ sd_notify(0, "READY=1");
+#endif
+ /* Run event loop */
+ uv_run(loop, UV_RUN_DEFAULT);
+ if (sock_file) {
+ unlink(sock_file);
+ }
+ uv_close((uv_handle_t *)&pipe, NULL); /* Seems OK even on the stopped loop. */
+ return EXIT_SUCCESS;
+}
+
+#if SYSTEMD_VERSION >= 227
+static void free_sd_socket_names(char **socket_names, int count)
+{
+ for (int i = 0; i < count; i++) {
+ free(socket_names[i]);
+ }
+ free(socket_names);
+}
+#endif
+
+static void args_init(struct args *args)
+{
+ memset(args, 0, sizeof(struct args));
+ /* Zeroed arrays are OK. */
+ args->forks = 1;
+ args->control_fd = -1;
+ args->interactive = true;
+ args->quiet = false;
+}
+
+/* Free pointed-to resources. */
+static void args_deinit(struct args *args)
+{
+ array_clear(args->addrs);
+ array_clear(args->addrs_tls);
+ for (int i = 0; i < args->fds.len; ++i)
+ free_const(args->fds.at[i].flags.kind);
+ array_clear(args->fds);
+ array_clear(args->config);
+}
+
+static long strtol_10(const char *s)
+{
+ if (!s) abort();
+ /* ^^ This shouldn't ever happen. When getopt_long() returns an option
+ * character that has a mandatory parameter, optarg can't be NULL. */
+ return strtol(s, NULL, 10);
+}
+
+/** Process arguments into struct args.
+ * @return >=0 if main() should be exited immediately.
+ */
+static int parse_args(int argc, char **argv, struct args *args)
+{
+ /* Long options. */
+ int c = 0, li = 0;
+ struct option opts[] = {
+ {"addr", required_argument, 0, 'a'},
+ {"tls", required_argument, 0, 't'},
+ {"fd", required_argument, 0, 'S'},
+ {"config", required_argument, 0, 'c'},
+ {"forks", required_argument, 0, 'f'},
+ {"verbose", no_argument, 0, 'v'},
+ {"quiet", no_argument, 0, 'q'},
+ {"version", no_argument, 0, 'V'},
+ {"help", no_argument, 0, 'h'},
+ {0, 0, 0, 0}
+ };
+ while ((c = getopt_long(argc, argv, "a:t:S:c:f:m:K:k:vqVh", opts, &li)) != -1) {
+ switch (c)
+ {
+ case 'a':
+ array_push(args->addrs, optarg);
+ break;
+ case 't':
+ array_push(args->addrs_tls, optarg);
+ break;
+ case 'c':
+ assert(optarg != NULL);
+ array_push(args->config, optarg);
+ break;
+ case 'f':
+ args->interactive = false;
+ args->forks = strtol_10(optarg);
+ if (args->forks <= 0) {
+				kr_log_error("[system] error: '-f' requires a positive"
+					" number, not '%s'\n", optarg);
+ return EXIT_FAILURE;
+ }
+ break;
+ case 'v':
+ kr_verbose_set(true);
+#ifdef NOVERBOSELOG
+ kr_log_info("--verbose flag has no effect due to compilation with -DNOVERBOSELOG.\n");
+#endif
+ break;
+ case 'q':
+ args->quiet = true;
+ break;
+ case 'V':
+ kr_log_info("%s, version %s\n", "Knot Resolver", PACKAGE_VERSION);
+ return EXIT_SUCCESS;
+ case 'h':
+ case '?':
+ help(argc, argv);
+ return EXIT_SUCCESS;
+ default:
+ help(argc, argv);
+ return EXIT_FAILURE;
+ case 'S':
+ (void)0;
+ flagged_fd_t ffd = { 0 };
+ char *endptr;
+ ffd.fd = strtol(optarg, &endptr, 10);
+ if (endptr != optarg && endptr[0] == '\0') {
+ /* Plain DNS */
+ ffd.flags.tls = false;
+ } else if (endptr[0] == ':' && strcasecmp(endptr + 1, "tls") == 0) {
+ /* DoT */
+ ffd.flags.tls = true;
+ /* We know what .sock_type should be but it wouldn't help. */
+ } else if (endptr[0] == ':' && endptr[1] != '\0') {
+ /* Some other kind; no checks here. */
+ ffd.flags.kind = strdup(endptr + 1);
+ } else {
+ kr_log_error("[system] incorrect value passed to '-S/--fd': %s\n",
+ optarg);
+ return EXIT_FAILURE;
+ }
+ array_push(args->fds, ffd);
+ break;
+ }
+ }
+ if (optind < argc) {
+ args->rundir = argv[optind];
+ }
+ return -1;
+}
+
+/** Just convert addresses to file-descriptors; clear *addrs on success.
+ * @note AF_UNIX is supported (starting with '/').
+ * @return zero or exit code for main()
+ */
+static int bind_sockets(addr_array_t *addrs, bool tls, flagged_fd_array_t *fds)
+{
+ bool has_error = false;
+ for (size_t i = 0; i < addrs->len; ++i) {
+ /* Get port and separate address string. */
+ uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT;
+ char addr_buf[INET6_ADDRSTRLEN + 1];
+ int ret;
+ const char *addr_str;
+ const int family = kr_straddr_family(addrs->at[i]);
+ if (family == AF_UNIX) {
+ ret = 0;
+ addr_str = addrs->at[i];
+ } else { /* internet socket (or garbage) */
+ ret = kr_straddr_split(addrs->at[i], addr_buf, &port);
+ addr_str = addr_buf;
+ }
+ /* Get sockaddr. */
+ struct sockaddr *sa = NULL;
+ if (ret == 0) {
+ sa = kr_straddr_socket(addr_str, port, NULL);
+ if (!sa) ret = kr_error(EINVAL); /* could be ENOMEM but unlikely */
+ }
+ flagged_fd_t ffd = { .flags = { .tls = tls } };
+ if (ret == 0 && !tls && family != AF_UNIX) {
+ /* AF_UNIX can do SOCK_DGRAM, but let's not support that *here*. */
+ ffd.fd = io_bind(sa, SOCK_DGRAM, NULL);
+ if (ffd.fd < 0)
+ ret = ffd.fd;
+ else if (array_push(*fds, ffd) < 0)
+ ret = kr_error(ENOMEM);
+ }
+ if (ret == 0) { /* common for TCP and TLS, including AF_UNIX cases */
+ ffd.fd = io_bind(sa, SOCK_STREAM, NULL);
+ if (ffd.fd < 0)
+ ret = ffd.fd;
+ else if (array_push(*fds, ffd) < 0)
+ ret = kr_error(ENOMEM);
+ }
+ free(sa);
+ if (ret != 0) {
+ kr_log_error("[system] bind to '%s'%s: %s\n",
+ addrs->at[i], tls ? " (TLS)" : "", kr_strerror(ret));
+ has_error = true;
+ }
+ }
+ array_clear(*addrs);
+ return has_error ? EXIT_FAILURE : kr_ok();
+}
+
+static int start_listening(struct network *net, flagged_fd_array_t *fds) {
+ int some_bad_ret = 0;
+ for (size_t i = 0; i < fds->len; ++i) {
+ flagged_fd_t *ffd = &fds->at[i];
+ int ret = network_listen_fd(net, ffd->fd, ffd->flags);
+ if (ret != 0) {
+ some_bad_ret = ret;
+ /* TODO: try logging address@port. It's not too important,
+ * because typical problems happen during binding already.
+ * (invalid address, permission denied) */
+ kr_log_error("[system] listen on fd=%d: %s\n",
+ ffd->fd, kr_strerror(ret));
+ /* Continue printing all of these before exiting. */
+ } else {
+ ffd->flags.kind = NULL; /* ownership transferred */
+ }
+ }
+ return some_bad_ret;
+}
+
+/* Drop POSIX 1003.1e capabilities. */
+// static void drop_capabilities(void)
+// {
+// #ifdef ENABLE_CAP_NG
+// /* Drop all capabilities. */
+// if (capng_have_capability(CAPNG_EFFECTIVE, CAP_SETPCAP)) {
+// capng_clear(CAPNG_SELECT_BOTH);
+
+// /* Apply. */
+// if (capng_apply(CAPNG_SELECT_BOTH) < 0) {
+// kr_log_error("[system] failed to set process capabilities: %s\n",
+// strerror(errno));
+// }
+// } else {
+// kr_log_info("[system] process not allowed to set capabilities, skipping\n");
+// }
+// #endif /* ENABLE_CAP_NG */
+// }
+
+int main(int argc, char **argv)
+{
+ struct args args;
+ args_init(&args);
+ int ret = parse_args(argc, argv, &args);
+ if (ret >= 0) goto cleanup_args;
+
+ ret = bind_sockets(&args.addrs, false, &args.fds);
+ if (ret) goto cleanup_args;
+ ret = bind_sockets(&args.addrs_tls, true, &args.fds);
+ if (ret) goto cleanup_args;
+
+#if SYSTEMD_VERSION >= 227
+ /* Accept passed sockets from systemd supervisor. */
+ char **socket_names = NULL;
+ int sd_nsocks = sd_listen_fds_with_names(0, &socket_names);
+ if (sd_nsocks < 0) {
+ kr_log_error("[system] failed passing sockets from systemd: %s\n",
+ kr_strerror(sd_nsocks));
+ free_sd_socket_names(socket_names, sd_nsocks);
+ ret = EXIT_FAILURE;
+ goto cleanup_args;
+ }
+ if (sd_nsocks > 0 && args.forks != 1) {
+ kr_log_error("[system] when run under systemd-style supervision, "
+ "use single-process only (bad: --forks=%d).\n", args.forks);
+ free_sd_socket_names(socket_names, sd_nsocks);
+ ret = EXIT_FAILURE;
+ goto cleanup_args;
+ }
+ for (int i = 0; i < sd_nsocks; ++i) {
+ /* when run under systemd supervision, do not use interactive mode */
+ args.interactive = false;
+ flagged_fd_t ffd = { .fd = SD_LISTEN_FDS_START + i };
+
+ if (!strcasecmp("control", socket_names[i])) {
+ if (args.control_fd != -1) {
+ kr_log_error("[system] multiple control sockets passed from systemd\n");
+ ret = EXIT_FAILURE;
+ break;
+ }
+ args.control_fd = ffd.fd;
+ free(socket_names[i]);
+ } else {
+ if (!strcasecmp("dns", socket_names[i])) {
+ free(socket_names[i]);
+ } else if (!strcasecmp("tls", socket_names[i])) {
+ ffd.flags.tls = true;
+ free(socket_names[i]);
+ } else {
+ ffd.flags.kind = socket_names[i];
+ }
+ array_push(args.fds, ffd);
+ }
+ /* Either freed or passed ownership. */
+ socket_names[i] = NULL;
+ }
+ free_sd_socket_names(socket_names, sd_nsocks);
+ if (ret) goto cleanup_args;
+#endif
+
+ /* Switch to rundir. */
+ if (args.rundir != NULL) {
+ /* FIXME: access isn't a good way if we start as root and drop privileges later */
+ if (access(args.rundir, W_OK) != 0
+ || chdir(args.rundir) != 0) {
+ kr_log_error("[system] rundir '%s': %s\n",
+ args.rundir, strerror(errno));
+ return EXIT_FAILURE;
+ }
+ }
+
+	/* Select which config files to load and verify that they are readable. */
+ bool load_defaults = true;
+ size_t i = 0;
+ while (i < args.config.len) {
+ const char *config = args.config.at[i];
+ if (strcmp(config, "-") == 0) {
+ load_defaults = false;
+ array_del(args.config, i);
+ continue; /* don't increment i */
+ } else if (access(config, R_OK) != 0) {
+ char cwd[PATH_MAX];
+ get_workdir(cwd, sizeof(cwd));
+ kr_log_error("[system] config '%s' (workdir '%s'): %s\n",
+ config, cwd, strerror(errno));
+ return EXIT_FAILURE;
+ }
+ i++;
+ }
+ if (args.config.len == 0 && access("config", R_OK) == 0)
+ array_push(args.config, "config");
+ if (load_defaults)
+ array_push(args.config, LIBDIR "/config-watcher.lua");
+
+ /* File-descriptor count limit: soft->hard. */
+ struct rlimit rlim;
+ ret = getrlimit(RLIMIT_NOFILE, &rlim);
+ if (ret == 0 && rlim.rlim_cur != rlim.rlim_max) {
+ kr_log_verbose("[system] increasing file-descriptor limit: %ld -> %ld\n",
+ (long)rlim.rlim_cur, (long)rlim.rlim_max);
+ rlim.rlim_cur = rlim.rlim_max;
+ ret = setrlimit(RLIMIT_NOFILE, &rlim);
+ }
+ if (ret) {
+ kr_log_error("[system] failed to get or set file-descriptor limit: %s\n",
+ strerror(errno));
+ } else if (rlim.rlim_cur < 512*1024) {
+ kr_log_info("[system] warning: hard limit for number of file-descriptors is only %ld but recommended value is 524288\n",
+ rlim.rlim_cur);
+ }
+
+ /* Connect forks with local socket */
+ fd_array_t ipc_set;
+ array_init(ipc_set);
+ /* Fork subprocesses if requested */
+ int fork_id = fork_workers(&ipc_set, args.forks);
+ if (fork_id < 0) {
+ return EXIT_FAILURE;
+ }
+
+ //kr_crypto_init();
+
+ /* Create a server engine. */
+ knot_mm_t pool = {
+ .ctx = mp_new (4096),
+ .alloc = (knot_mm_alloc_t) mp_alloc
+ };
+ /** Static to work around lua_pushlightuserdata() limitations.
+ * TODO: convert to a proper singleton like worker, most likely. */
+ static struct engine engine;
+ ret = engine_init(&engine, &pool);
+ if (ret != 0) {
+ kr_log_error("[system] failed to initialize engine: %s\n", kr_strerror(ret));
+ return EXIT_FAILURE;
+ }
+ /* Initialize the worker */
+ ret = worker_init(&engine, fork_id, args.forks);
+ if (ret != 0) {
+ kr_log_error("[system] failed to initialize worker: %s\n", kr_strerror(ret));
+ return EXIT_FAILURE;
+ }
+
+ uv_loop_t *loop = uv_default_loop();
+ /* Catch some signals. */
+ uv_signal_t sigint, sigterm;
+ if (true) ret = uv_signal_init(loop, &sigint);
+ if (!ret) ret = uv_signal_init(loop, &sigterm);
+ if (!ret) ret = uv_signal_start(&sigint, signal_handler, SIGINT);
+ if (!ret) ret = uv_signal_start(&sigterm, signal_handler, SIGTERM);
+ /* Block SIGPIPE; see https://github.com/libuv/libuv/issues/45 */
+ if (!ret && signal(SIGPIPE, SIG_IGN) == SIG_ERR) ret = errno;
+ if (!ret) {
+ /* Catching SIGBUS via uv_signal_* can't work; see:
+ * https://github.com/libuv/libuv/pull/1987 */
+ struct sigaction sa;
+ memset(&sa, 0, sizeof(sa));
+ sa.sa_sigaction = sigbus_handler;
+ sa.sa_flags = SA_SIGINFO;
+ if (sigaction(SIGBUS, &sa, NULL)) {
+ ret = errno;
+ }
+ }
+ if (ret) {
+ kr_log_error("[system] failed to set up signal handlers: %s\n",
+				strerror(abs(ret)));
+ ret = EXIT_FAILURE;
+ goto cleanup;
+ }
+ /* Profiling: avoid SIGPROF waking up the event loop. Otherwise the profiles
+ * (of the usual type) may skew results, e.g. epoll_pwait() taking lots of time. */
+ ret = uv_loop_configure(loop, UV_LOOP_BLOCK_SIGNAL, SIGPROF);
+ if (ret) {
+ kr_log_info("[system] failed to block SIGPROF in event loop, ignoring: %s\n",
+ uv_strerror(ret));
+ }
+
+ /* Start listening, in the sense of network_listen_fd(). */
+ // if (start_listening(&engine.net, &args.fds) != 0) {
+ // ret = EXIT_FAILURE;
+ // goto cleanup;
+ // }
+
+ // ret = udp_queue_init_global(loop);
+ // if (ret) {
+ // kr_log_error("[system] failed to initialize UDP queue: %s\n",
+ // kr_strerror(ret));
+ // ret = EXIT_FAILURE;
+ // goto cleanup;
+ // }
+
+ /* Start the scripting engine */
+ if (engine_load_sandbox(&engine) != 0) {
+ ret = EXIT_FAILURE;
+ goto cleanup;
+ }
+
+ for (i = 0; i < args.config.len; ++i) {
+ const char *config = args.config.at[i];
+ if (engine_loadconf(&engine, config) != 0) {
+ ret = EXIT_FAILURE;
+ goto cleanup;
+ }
+ lua_settop(engine.L, 0);
+ }
+
+ //drop_capabilities();
+
+ if (engine_start(&engine) != 0) {
+ ret = EXIT_FAILURE;
+ goto cleanup;
+ }
+
+ // if (network_engage_endpoints(&engine.net)) {
+ // ret = EXIT_FAILURE;
+ // goto cleanup;
+ // }
+
+ /* Run the event loop */
+ ret = run_worker(loop, &engine, &ipc_set, fork_id == 0, &args);
+
+cleanup:/* Cleanup. */
+ engine_deinit(&engine);
+ worker_deinit();
+ if (loop != NULL) {
+ uv_loop_close(loop);
+ }
+ mp_delete(pool.ctx);
+cleanup_args:
+ args_deinit(&args);
+ //kr_crypto_cleanup();
+ return ret;
+}
--- /dev/null
+# kres-watcher
+
+watcher_src = files([
+ 'bindings/cache.c',
+ 'bindings/event.c',
+ 'bindings/impl.c',
+ 'bindings/modules.c',
+ 'bindings/net.c',
+ 'bindings/worker.c',
+ 'engine.c',
+ 'ffimodule.c',
+ 'io.c',
+ 'main.c',
+ 'network.c',
+ 'session.c',
+ 'tls.c',
+ 'tls_ephemeral_credentials.c',
+ 'tls_session_ticket-srv.c',
+ 'udp_queue.c',
+ 'watcher.c',
+ 'worker.c',
+ 'zimport.c',
+ 'sr_subscriptions.c',
+ 'dbus_control.c'
+])
+c_src_lint += watcher_src
+
+watcher_deps = [
+ contrib_dep,
+ kresconfig_dep,
+ libkres_dep,
+ libknot,
+ libzscanner,
+ libdnssec,
+ libuv,
+ luajit,
+ gnutls,
+ libsystemd,
+ capng,
+ libyang,
+ libsysrepo,
+]
+
+subdir('lua')
+
+if build_sysrepo
+ kres_watcher = executable(
+ 'kres-watcher',
+ watcher_src,
+ sysrepo_common_src,
+ dependencies: watcher_deps,
+ export_dynamic: true,
+ install: true,
+ install_dir: get_option('sbindir'),
+ )
+endif
--- /dev/null
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "network.h"
+
+#include "bindings/impl.h"
+#include "io.h"
+#include "tls.h"
+#include "worker.h"
+
+#include <assert.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
+{
+ if (net != NULL) {
+ net->loop = loop;
+ net->endpoints = map_make(NULL);
+ net->endpoint_kinds = trie_create(NULL);
+ net->tls_client_params = NULL;
+ net->tls_session_ticket_ctx = /* unsync. random, by default */
+ tls_session_ticket_ctx_create(loop, NULL, 0);
+ net->tcp.in_idle_timeout = 10000;
+ net->tcp.tls_handshake_timeout = TLS_MAX_HANDSHAKE_TIME;
+ net->tcp_backlog = tcp_backlog;
+ }
+}
+
+/** Notify the registered function about an endpoint being opened.
+ * log_addr is the address string used for logging and for the Lua callback. */
+static int endpoint_open_lua_cb(struct network *net, struct endpoint *ep,
+ const char *log_addr)
+{
+ const bool ok = ep->flags.kind && !ep->handle && !ep->engaged && ep->fd != -1;
+ if (!ok) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ /* First find callback in the endpoint registry. */
+ lua_State *L = the_worker->engine->L;
+ void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind,
+ strlen(ep->flags.kind));
+ if (!pp && net->missing_kind_is_error) {
+ kr_log_error("warning: network socket kind '%s' not handled when opening '%s",
+ ep->flags.kind, log_addr);
+ if (ep->family != AF_UNIX)
+ kr_log_error("#%d", ep->port);
+ kr_log_error("'. Likely causes: typo or not loading 'http' module.\n");
+ /* No hard error, for now. LATER: perhaps differentiate between
+ * explicit net.listen() calls and "just unused" systemd sockets.
+ return kr_error(ENOENT);
+ */
+ }
+ if (!pp) return kr_ok();
+
+ /* Now execute the callback. */
+ const int fun_id = (char *)*pp - (char *)NULL;
+ lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
+ lua_pushboolean(L, true /* open */);
+ lua_pushpointer(L, ep);
+ if (ep->family == AF_UNIX) {
+ lua_pushstring(L, log_addr);
+ } else {
+ lua_pushfstring(L, "%s#%d", log_addr, ep->port);
+ }
+ if (lua_pcall(L, 3, 0, 0)) {
+ kr_log_error("error opening %s: %s\n", log_addr, lua_tostring(L, -1));
+ return kr_error(ENOSYS); /* TODO: better value? */
+ }
+ ep->engaged = true;
+ return kr_ok();
+}
+
+static int engage_endpoint_array(const char *key, void *endpoints, void *net)
+{
+ endpoint_array_t *eps = (endpoint_array_t *)endpoints;
+ for (int i = 0; i < eps->len; ++i) {
+ struct endpoint *ep = &eps->at[i];
+ const bool match = !ep->engaged && ep->flags.kind;
+ if (!match) continue;
+ int ret = endpoint_open_lua_cb(net, ep, key);
+ if (ret) return ret;
+ }
+ return 0;
+}
+int network_engage_endpoints(struct network *net)
+{
+ if (net->missing_kind_is_error)
+ return kr_ok(); /* maybe weird, but let's make it idempotent */
+ net->missing_kind_is_error = true;
+ int ret = map_walk(&net->endpoints, engage_endpoint_array, net);
+ if (ret) {
+ net->missing_kind_is_error = false; /* avoid the same errors when closing */
+ return ret;
+ }
+ return kr_ok();
+}
+
+
+/** Notify the registered function about endpoint about to be closed. */
+static void endpoint_close_lua_cb(struct network *net, struct endpoint *ep)
+{
+ lua_State *L = the_worker->engine->L;
+ void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind,
+ strlen(ep->flags.kind));
+ if (!pp && net->missing_kind_is_error) {
+ kr_log_error("internal error: missing kind '%s' in endpoint registry\n",
+ ep->flags.kind);
+ return;
+ }
+ if (!pp) return;
+
+ const int fun_id = (char *)*pp - (char *)NULL;
+ lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
+ lua_pushboolean(L, false /* close */);
+ lua_pushpointer(L, ep);
+ lua_pushstring(L, "FIXME:endpoint-identifier");
+ if (lua_pcall(L, 3, 0, 0)) {
+ kr_log_error("failed to close FIXME:endpoint-identifier: %s\n",
+ lua_tostring(L, -1));
+ }
+}
+
+static void endpoint_close(struct network *net, struct endpoint *ep, bool force)
+{
+ assert(!ep->handle != !ep->flags.kind);
+ if (ep->family == AF_UNIX) { /* The FS name would be left behind. */
+ /* Extract local address for this socket. */
+ struct sockaddr_un sa;
+ sa.sun_path[0] = '\0'; /*< probably only for lint:scan-build */
+ socklen_t addr_len = sizeof(sa);
+ if (getsockname(ep->fd, (struct sockaddr *)&sa, &addr_len)
+ || unlink(sa.sun_path)) {
+ kr_log_error("error (ignored) when closing unix socket (fd = %d): %s\n",
+ ep->fd, strerror(errno));
+ return;
+ }
+ }
+
+ if (ep->flags.kind) { /* Special endpoint. */
+ if (ep->engaged) {
+ endpoint_close_lua_cb(net, ep);
+ }
+ if (ep->fd > 0) {
+ close(ep->fd); /* nothing to do with errors */
+ }
+ free_const(ep->flags.kind);
+ return;
+ }
+
+ if (force) { /* Force close if event loop isn't running. */
+ if (ep->fd >= 0) {
+ close(ep->fd);
+ }
+ if (ep->handle) {
+ ep->handle->loop = NULL;
+ io_free(ep->handle);
+ }
+ } else { /* Asynchronous close */
+ uv_close(ep->handle, io_free);
+ }
+}
+
+/** Endpoint visitor (see @file map.h) */
+static int close_key(const char *key, void *val, void *net)
+{
+ endpoint_array_t *ep_array = val;
+ for (int i = 0; i < ep_array->len; ++i) {
+ endpoint_close(net, &ep_array->at[i], true);
+ }
+ return 0;
+}
+
+static int free_key(const char *key, void *val, void *ext)
+{
+ endpoint_array_t *ep_array = val;
+ array_clear(*ep_array);
+ free(ep_array);
+ return kr_ok();
+}
+
+int kind_unregister(trie_val_t *tv, void *L)
+{
+ int fun_id = (char *)*tv - (char *)NULL;
+ luaL_unref(L, LUA_REGISTRYINDEX, fun_id);
+ return 0;
+}
+
+void network_close_force(struct network *net)
+{
+ if (net != NULL) {
+ map_walk(&net->endpoints, close_key, net);
+ map_walk(&net->endpoints, free_key, 0);
+ map_clear(&net->endpoints);
+ }
+}
+
+void network_deinit(struct network *net)
+{
+ if (net != NULL) {
+ network_close_force(net);
+ trie_apply(net->endpoint_kinds, kind_unregister, the_worker->engine->L);
+ trie_free(net->endpoint_kinds);
+
+ tls_credentials_free(net->tls_credentials);
+ tls_client_params_free(net->tls_client_params);
+ tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx);
+ #ifndef NDEBUG
+ memset(net, 0, sizeof(*net));
+ #endif
+ }
+}
+
+/** Fetch or create endpoint array and insert endpoint (shallow memcpy). */
+static int insert_endpoint(struct network *net, const char *addr, struct endpoint *ep)
+{
+ /* Fetch or insert address into map */
+ endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
+ if (ep_array == NULL) {
+ ep_array = malloc(sizeof(*ep_array));
+ if (ep_array == NULL) {
+ return kr_error(ENOMEM);
+ }
+ if (map_set(&net->endpoints, addr, ep_array) != 0) {
+ free(ep_array);
+ return kr_error(ENOMEM);
+ }
+ array_init(*ep_array);
+ }
+
+ if (array_reserve(*ep_array, ep_array->len + 1)) {
+ return kr_error(ENOMEM);
+ }
+ memcpy(&ep_array->at[ep_array->len++], ep, sizeof(*ep));
+ return kr_ok();
+}
+
+/** Open endpoint protocols. ep->flags were pre-set. */
+static int open_endpoint(struct network *net, struct endpoint *ep,
+ const struct sockaddr *sa, const char *log_addr)
+{
+ if ((sa != NULL) == (ep->fd != -1)) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ if (ep->handle) {
+ return kr_error(EEXIST);
+ }
+
+ if (sa) {
+ ep->fd = io_bind(sa, ep->flags.sock_type, &ep->flags);
+ if (ep->fd < 0) return ep->fd;
+ }
+ if (ep->flags.kind) {
+ /* This EP isn't to be managed internally after binding. */
+ return endpoint_open_lua_cb(net, ep, log_addr);
+ } else {
+ ep->engaged = true;
+ /* .engaged seems not really meaningful with .kind == NULL, but... */
+ }
+
+ if (ep->family == AF_UNIX) {
+ /* Some parts of connection handling would need more work,
+ * so let's support AF_UNIX only with .kind != NULL for now. */
+ kr_log_error("[system] AF_UNIX only supported with set { kind = '...' }\n");
+ return kr_error(EAFNOSUPPORT);
+ /*
+ uv_pipe_t *ep_handle = malloc(sizeof(uv_pipe_t));
+ */
+ }
+
+ if (ep->flags.sock_type == SOCK_DGRAM) {
+ if (ep->flags.tls) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ uv_udp_t *ep_handle = malloc(sizeof(uv_udp_t));
+ ep->handle = (uv_handle_t *)ep_handle;
+ if (!ep->handle) {
+ return kr_error(ENOMEM);
+ }
+ return io_listen_udp(net->loop, ep_handle, ep->fd);
+ } /* else */
+
+ if (ep->flags.sock_type == SOCK_STREAM) {
+ uv_tcp_t *ep_handle = malloc(sizeof(uv_tcp_t));
+ ep->handle = (uv_handle_t *)ep_handle;
+ if (!ep->handle) {
+ return kr_error(ENOMEM);
+ }
+ return io_listen_tcp(net->loop, ep_handle, ep->fd,
+ net->tcp_backlog, ep->flags.tls);
+ } /* else */
+
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+}
+
+/** @internal Fetch a pointer to endpoint of given parameters (or NULL).
+ * Beware that there might be multiple matches, though that's not common. */
+static struct endpoint * endpoint_get(struct network *net, const char *addr,
+ uint16_t port, endpoint_flags_t flags)
+{
+ endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
+ if (!ep_array) {
+ return NULL;
+ }
+ for (int i = 0; i < ep_array->len; ++i) {
+ struct endpoint *ep = &ep_array->at[i];
+ if (ep->port == port && endpoint_flags_eq(ep->flags, flags)) {
+ return ep;
+ }
+ }
+ return NULL;
+}
+
+/** \note exactly one of sa != NULL and ep.fd != -1 must be passed;
+ * \note ownership of ep.flags.* is taken on success. */
+static int create_endpoint(struct network *net, const char *addr_str,
+ struct endpoint *ep, const struct sockaddr *sa)
+{
+ int ret = open_endpoint(net, ep, sa, addr_str);
+ if (ret == 0) {
+ ret = insert_endpoint(net, addr_str, ep);
+ }
+ if (ret != 0 && ep->handle) {
+ endpoint_close(net, ep, false);
+ }
+ return ret;
+}
+
+int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags)
+{
+ /* Extract fd's socket type. */
+ socklen_t len = sizeof(flags.sock_type);
+ int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len);
+ if (ret != 0) {
+ return kr_error(errno);
+ }
+ if (flags.sock_type == SOCK_DGRAM && !flags.kind && flags.tls) {
+ assert(!EINVAL); /* Perhaps DTLS some day. */
+ return kr_error(EINVAL);
+ }
+ if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) {
+ return kr_error(EBADF);
+ }
+
+ /* Extract local address for this socket. */
+ struct sockaddr_storage ss = { .ss_family = AF_UNSPEC };
+ socklen_t addr_len = sizeof(ss);
+ ret = getsockname(fd, (struct sockaddr *)&ss, &addr_len);
+ if (ret != 0) {
+ return kr_error(errno);
+ }
+
+ struct endpoint ep = {
+ .flags = flags,
+ .family = ss.ss_family,
+ .fd = fd,
+ };
+ /* Extract address string and port. */
+ char addr_buf[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */
+ const char *addr_str;
+ switch (ep.family) {
+ case AF_INET:
+ ret = uv_ip4_name((const struct sockaddr_in*)&ss, addr_buf, sizeof(addr_buf));
+ addr_str = addr_buf;
+ ep.port = ntohs(((struct sockaddr_in *)&ss)->sin_port);
+ break;
+ case AF_INET6:
+ ret = uv_ip6_name((const struct sockaddr_in6*)&ss, addr_buf, sizeof(addr_buf));
+ addr_str = addr_buf;
+ ep.port = ntohs(((struct sockaddr_in6 *)&ss)->sin6_port);
+ break;
+ case AF_UNIX:
+ /* No SOCK_DGRAM with AF_UNIX support, at least for now. */
+ ret = flags.sock_type == SOCK_STREAM ? kr_ok() : kr_error(EAFNOSUPPORT);
+ addr_str = ((struct sockaddr_un *)&ss)->sun_path;
+ break;
+ default:
+ ret = kr_error(EAFNOSUPPORT);
+ }
+ if (ret) return ret;
+
+ /* always create endpoint for supervisor supplied fd
+ * even if addr+port is not unique */
+ return create_endpoint(net, addr_str, &ep, NULL);
+}
+
+int network_listen(struct network *net, const char *addr, uint16_t port,
+ endpoint_flags_t flags)
+{
+ if (net == NULL || addr == 0 || port == 0) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ if (endpoint_get(net, addr, port, flags)) {
+ return kr_error(EADDRINUSE); /* Already listening */
+ }
+
+ /* Parse address. */
+ const struct sockaddr *sa = kr_straddr_socket(addr, port, NULL);
+ if (!sa) {
+ return kr_error(EINVAL);
+ }
+ struct endpoint ep = {
+ .flags = flags,
+ .fd = -1,
+ .port = port,
+ .family = sa->sa_family,
+ };
+ int ret = create_endpoint(net, addr, &ep, sa);
+ free_const(sa);
+ return ret;
+}
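+
+/* Usage sketch (illustrative): listening on loopback for plain DNS and DoT
+ * takes one call per protocol; there is intentionally no shorthand covering
+ * both UDP and TCP at once:
+ *     network_listen(net, "127.0.0.1", KR_DNS_PORT,
+ *                    (endpoint_flags_t){ .sock_type = SOCK_DGRAM });
+ *     network_listen(net, "127.0.0.1", KR_DNS_TLS_PORT,
+ *                    (endpoint_flags_t){ .sock_type = SOCK_STREAM, .tls = true });
+ */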
+
+int network_close(struct network *net, const char *addr, int port)
+{
+ endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
+ if (!ep_array) {
+ return kr_error(ENOENT);
+ }
+
+ size_t i = 0;
+ bool matched = false; /*< at least one match */
+ while (i < ep_array->len) {
+ struct endpoint *ep = &ep_array->at[i];
+ if (port < 0 || ep->port == port) {
+ endpoint_close(net, ep, false);
+ array_del(*ep_array, i);
+ matched = true;
+ /* do not advance i */
+ } else {
+ ++i;
+ }
+ }
+ if (!matched) {
+ return kr_error(ENOENT);
+ }
+
+ /* Collapse key if it has no endpoint. */
+ if (ep_array->len == 0) {
+ array_clear(*ep_array);
+ free(ep_array);
+ map_del(&net->endpoints, addr);
+ }
+
+ return kr_ok();
+}
+
+void network_new_hostname(struct network *net, struct engine *engine)
+{
+ if (net->tls_credentials &&
+ net->tls_credentials->ephemeral_servicename) {
+ struct tls_credentials *newcreds;
+ newcreds = tls_get_ephemeral_credentials(engine);
+ if (newcreds) {
+ tls_credentials_release(net->tls_credentials);
+ net->tls_credentials = newcreds;
+ kr_log_info("[tls] Updated ephemeral X.509 cert with new hostname\n");
+ } else {
+ kr_log_error("[tls] Failed to update ephemeral X.509 cert with new hostname, using existing one\n");
+ }
+ }
+}
+
+#ifdef SO_ATTACH_BPF
+static int set_bpf_cb(const char *key, void *val, void *ext)
+{
+ endpoint_array_t *endpoints = (endpoint_array_t *)val;
+ assert(endpoints != NULL);
+ int *bpffd = (int *)ext;
+ assert(bpffd != NULL);
+
+ for (size_t i = 0; i < endpoints->len; i++) {
+ struct endpoint *endpoint = &endpoints->at[i];
+ uv_os_fd_t sockfd = -1;
+ if (endpoint->handle != NULL)
+ uv_fileno(endpoint->handle, &sockfd);
+ assert(sockfd != -1);
+
+ if (setsockopt(sockfd, SOL_SOCKET, SO_ATTACH_BPF, bpffd, sizeof(int)) != 0) {
+ return 1; /* return error (and stop iterating over net->endpoints) */
+ }
+ }
+ return 0; /* OK */
+}
+#endif
+
+int network_set_bpf(struct network *net, int bpf_fd)
+{
+#ifdef SO_ATTACH_BPF
+ if (map_walk(&net->endpoints, set_bpf_cb, &bpf_fd) != 0) {
+ /* set_bpf_cb() has returned error. */
+ network_clear_bpf(net);
+ return 0;
+ }
+#else
+	kr_log_error("[network] SO_ATTACH_BPF socket option is not supported\n");
+ (void)net;
+ (void)bpf_fd;
+ return 0;
+#endif
+ return 1;
+}
+
+#ifdef SO_DETACH_BPF
+static int clear_bpf_cb(const char *key, void *val, void *ext)
+{
+ endpoint_array_t *endpoints = (endpoint_array_t *)val;
+ assert(endpoints != NULL);
+
+ for (size_t i = 0; i < endpoints->len; i++) {
+ struct endpoint *endpoint = &endpoints->at[i];
+ uv_os_fd_t sockfd = -1;
+ if (endpoint->handle != NULL)
+ uv_fileno(endpoint->handle, &sockfd);
+ assert(sockfd != -1);
+
+ if (setsockopt(sockfd, SOL_SOCKET, SO_DETACH_BPF, NULL, 0) != 0) {
+ kr_log_error("[network] failed to clear SO_DETACH_BPF socket option\n");
+ }
+ /* Proceed even if setsockopt() failed,
+ * as we want to process all opened sockets. */
+ }
+ return 0;
+}
+#endif
+
+void network_clear_bpf(struct network *net)
+{
+#ifdef SO_DETACH_BPF
+ map_walk(&net->endpoints, clear_bpf_cb, NULL);
+#else
+	kr_log_error("[network] SO_DETACH_BPF socket option is not supported\n");
+ (void)net;
+#endif
+}
--- /dev/null
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "tls.h"
+
+#include "lib/generic/array.h"
+#include "lib/generic/map.h"
+#include "lib/generic/trie.h"
+
+#include <uv.h>
+#include <stdbool.h>
+
+
+struct engine;
+
+/** Ways to listen on a socket. */
+typedef struct {
+ int sock_type; /**< SOCK_DGRAM or SOCK_STREAM */
+	bool tls;      /**< only used together with .kind == NULL and TCP (SOCK_STREAM) */
+ const char *kind; /**< tag for other types than the three usual */
+ bool freebind; /**< used for binding to non-local address **/
+} endpoint_flags_t;
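+
+/* Illustrative flag combinations (sketch): plain UDP DNS and DoT differ only
+ * in these flags, while "special" endpoints carry a .kind tag instead (the
+ * kind name below is hypothetical):
+ *     endpoint_flags_t udp  = { .sock_type = SOCK_DGRAM };
+ *     endpoint_flags_t dot  = { .sock_type = SOCK_STREAM, .tls = true };
+ *     endpoint_flags_t http = { .sock_type = SOCK_STREAM, .kind = "webmgmt" };
+ */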
+
+static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2)
+{
+ if (f1.sock_type != f2.sock_type)
+ return false;
+ if (f1.kind && f2.kind)
+		return strcasecmp(f1.kind, f2.kind) == 0;
+ else
+ return f1.tls == f2.tls && f1.kind == f2.kind;
+}
+
+/** Wrapper for a single socket to listen on.
+ * There are two types: normal endpoints have a .handle, special ones have .flags.kind (never both).
+ *
+ * LATER: .family might be unexpected for IPv4-in-IPv6 addresses.
+ * ATM AF_UNIX is only supported with flags.kind != NULL
+ */
+struct endpoint {
+ uv_handle_t *handle; /**< uv_udp_t or uv_tcp_t; NULL in case flags.kind != NULL */
+ int fd; /**< POSIX file-descriptor; always used. */
+ int family; /**< AF_INET or AF_INET6 or AF_UNIX */
+ uint16_t port; /**< TCP/UDP port. Meaningless with AF_UNIX. */
+ bool engaged; /**< to some module or internally */
+ endpoint_flags_t flags;
+};
+
+/** @cond internal Array of endpoints */
+typedef array_t(struct endpoint) endpoint_array_t;
+/* @endcond */
+
+struct net_tcp_param {
+ uint64_t in_idle_timeout;
+ uint64_t tls_handshake_timeout;
+};
+
+struct network {
+ uv_loop_t *loop;
+
+ /** Map: address string -> endpoint_array_t.
+ * \note even same address-port-flags tuples may appear.
+ * TODO: trie_t, keyed on *binary* address-port pair. */
+ map_t endpoints;
+
+ /** Registry of callbacks for special endpoint kinds (for opening/closing).
+ * Map: kind (lowercased) -> lua function ID converted to void *
+ * The ID is the usual: raw int index in the LUA_REGISTRYINDEX table. */
+ trie_t *endpoint_kinds;
+ /** See network_engage_endpoints() */
+ bool missing_kind_is_error;
+
+ struct tls_credentials *tls_credentials;
+ tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */
+ struct tls_session_ticket_ctx *tls_session_ticket_ctx;
+ struct net_tcp_param tcp;
+ int tcp_backlog;
+};
+
+void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog);
+void network_deinit(struct network *net);
+
+/** Start listening on addr#port with flags.
+ * \note if we did listen on that combination already,
+ * nothing is done and kr_error(EADDRINUSE) is returned.
+ * \note there's no short-hand to listen both on UDP and TCP.
+ * \note ownership of flags.* is taken on success. TODO: non-success?
+ */
+int network_listen(struct network *net, const char *addr, uint16_t port,
+ endpoint_flags_t flags);
+
+/** Start listening on an open file-descriptor.
+ * \note flags.sock_type isn't meaningful here.
+ * \note ownership of flags.* is taken on success. TODO: non-success?
+ */
+int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags);
+
+/** Stop listening on all endpoints with matching addr#port.
+ * port < 0 serves as a wild-card.
+ * \return kr_error(ENOENT) if nothing matched. */
+int network_close(struct network *net, const char *addr, int port);
+
+/** Close all endpoints immediately (no waiting for UV loop). */
+void network_close_force(struct network *net);
+
+/** Enforce that all endpoints are registered from now on.
+ * This only does anything with struct endpoint::flags.kind != NULL. */
+int network_engage_endpoints(struct network *net);
+
+int network_set_tls_cert(struct network *net, const char *cert);
+int network_set_tls_key(struct network *net, const char *key);
+void network_new_hostname(struct network *net, struct engine *engine);
+int network_set_bpf(struct network *net, int bpf_fd);
+void network_clear_bpf(struct network *net);
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <assert.h>
+
+#include <libknot/packet/pkt.h>
+
+#include "lib/defines.h"
+#include "session.h"
+#include "engine.h"
+#include "tls.h"
+#include "worker.h"
+#include "io.h"
+#include "lib/generic/queue.h"
+
+#define TLS_CHUNK_SIZE (16 * 1024)
+
+/* Per-socket (TCP or UDP) persistent structure.
+ *
+ * In particular, note that for UDP clients there is just one session per socket,
+ * shared by all clients. For TCP/TLS the socket is connection-specific,
+ * i.e. one session per connection.
+ */
+struct session {
+ struct session_flags sflags; /**< miscellaneous flags. */
+ union inaddr peer; /**< address of peer; not for UDP clients (downstream) */
+ union inaddr sockname; /**< our local address; for UDP it may be a wildcard */
+ uv_handle_t *handle; /**< libuv handle for IO operations. */
+ uv_timer_t timeout; /**< libuv handle for timer. */
+
+ struct tls_ctx_t *tls_ctx; /**< server side tls-related data. */
+ struct tls_client_ctx_t *tls_client_ctx; /**< client side tls-related data. */
+
+	trie_t *tasks;				/**< list of tasks associated with given session. */
+ queue_t(struct qr_task *) waiting; /**< list of tasks waiting for sending to upstream. */
+
+ uint8_t *wire_buf; /**< Buffer for DNS message. */
+ ssize_t wire_buf_size; /**< Buffer size. */
+ ssize_t wire_buf_start_idx; /**< Data start offset in wire_buf. */
+ ssize_t wire_buf_end_idx; /**< Data end offset in wire_buf. */
+ uint64_t last_activity; /**< Time of last IO activity (if any occurs).
+ * Otherwise session creation time. */
+};
+
+static void on_session_close(uv_handle_t *handle)
+{
+ struct session *session = handle->data;
+ assert(session->handle == handle); (void)session;
+ io_free(handle);
+}
+
+static void on_session_timer_close(uv_handle_t *timer)
+{
+ struct session *session = timer->data;
+ uv_handle_t *handle = session->handle;
+ assert(handle && handle->data == session);
+ assert (session->sflags.outgoing || handle->type == UV_TCP);
+ if (!uv_is_closing(handle)) {
+ uv_close(handle, on_session_close);
+ }
+}
+
+void session_free(struct session *session)
+{
+ if (session) {
+ assert(session_is_empty(session));
+ session_clear(session);
+ free(session);
+ }
+}
+
+void session_clear(struct session *session)
+{
+ assert(session_is_empty(session));
+ if (session->handle && session->handle->type == UV_TCP) {
+ free(session->wire_buf);
+ }
+ trie_clear(session->tasks);
+ trie_free(session->tasks);
+ queue_deinit(session->waiting);
+ tls_free(session->tls_ctx);
+ tls_client_ctx_free(session->tls_client_ctx);
+ memset(session, 0, sizeof(*session));
+}
+
+void session_close(struct session *session)
+{
+ assert(session_is_empty(session));
+ if (session->sflags.closing) {
+ return;
+ }
+
+ uv_handle_t *handle = session->handle;
+ io_stop_read(handle);
+ session->sflags.closing = true;
+
+ if (!uv_is_closing((uv_handle_t *)&session->timeout)) {
+ uv_timer_stop(&session->timeout);
+ if (session->tls_client_ctx) {
+ tls_close(&session->tls_client_ctx->c);
+ }
+ if (session->tls_ctx) {
+ tls_close(&session->tls_ctx->c);
+ }
+
+ session->timeout.data = session;
+ uv_close((uv_handle_t *)&session->timeout, on_session_timer_close);
+ }
+}
+
+int session_start_read(struct session *session)
+{
+ return io_start_read(session->handle);
+}
+
+int session_stop_read(struct session *session)
+{
+ return io_stop_read(session->handle);
+}
+
+int session_waitinglist_push(struct session *session, struct qr_task *task)
+{
+ queue_push(session->waiting, task);
+ worker_task_ref(task);
+ return kr_ok();
+}
+
+struct qr_task *session_waitinglist_get(const struct session *session)
+{
+ return (queue_len(session->waiting) > 0) ? (queue_head(session->waiting)) : NULL;
+}
+
+struct qr_task *session_waitinglist_pop(struct session *session, bool deref)
+{
+ struct qr_task *t = session_waitinglist_get(session);
+ queue_pop(session->waiting);
+ if (deref) {
+ worker_task_unref(t);
+ }
+ return t;
+}
+
+int session_tasklist_add(struct session *session, struct qr_task *task)
+{
+ trie_t *t = session->tasks;
+ uint16_t task_msg_id = 0;
+ const char *key = NULL;
+ size_t key_len = 0;
+ if (session->sflags.outgoing) {
+ knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
+ task_msg_id = knot_wire_get_id(pktbuf->wire);
+ key = (const char *)&task_msg_id;
+ key_len = sizeof(task_msg_id);
+ } else {
+ key = (const char *)&task;
+ key_len = sizeof(char *);
+ }
+ trie_val_t *v = trie_get_ins(t, key, key_len);
+ if (unlikely(!v)) {
+ assert(false);
+ return kr_error(ENOMEM);
+ }
+ if (*v == NULL) {
+ *v = task;
+ worker_task_ref(task);
+ } else if (*v != task) {
+ assert(false);
+ return kr_error(EINVAL);
+ }
+ return kr_ok();
+}
+
+int session_tasklist_del(struct session *session, struct qr_task *task)
+{
+ trie_t *t = session->tasks;
+ uint16_t task_msg_id = 0;
+ const char *key = NULL;
+ size_t key_len = 0;
+ trie_val_t val;
+ if (session->sflags.outgoing) {
+ knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
+ task_msg_id = knot_wire_get_id(pktbuf->wire);
+ key = (const char *)&task_msg_id;
+ key_len = sizeof(task_msg_id);
+ } else {
+ key = (const char *)&task;
+ key_len = sizeof(char *);
+ }
+ int ret = trie_del(t, key, key_len, &val);
+ if (ret == kr_ok()) {
+ assert(val == task);
+ worker_task_unref(val);
+ }
+ return ret;
+}
+
+struct qr_task *session_tasklist_get_first(struct session *session)
+{
+ trie_val_t *val = trie_get_first(session->tasks, NULL, NULL);
+ return val ? (struct qr_task *) *val : NULL;
+}
+
+struct qr_task *session_tasklist_del_first(struct session *session, bool deref)
+{
+ trie_val_t val = NULL;
+ int res = trie_del_first(session->tasks, NULL, NULL, &val);
+ if (res != kr_ok()) {
+ val = NULL;
+ } else if (deref) {
+ worker_task_unref(val);
+ }
+ return (struct qr_task *)val;
+}
+struct qr_task* session_tasklist_del_msgid(const struct session *session, uint16_t msg_id)
+{
+ trie_t *t = session->tasks;
+ assert(session->sflags.outgoing);
+ struct qr_task *ret = NULL;
+ const char *key = (const char *)&msg_id;
+ size_t key_len = sizeof(msg_id);
+ trie_val_t val;
+ int res = trie_del(t, key, key_len, &val);
+ if (res == kr_ok()) {
+ if (worker_task_numrefs(val) > 1) {
+ ret = val;
+ }
+ worker_task_unref(val);
+ }
+ return ret;
+}
+
+struct qr_task* session_tasklist_find_msgid(const struct session *session, uint16_t msg_id)
+{
+ trie_t *t = session->tasks;
+ assert(session->sflags.outgoing);
+ struct qr_task *ret = NULL;
+ trie_val_t *val = trie_get_try(t, (char *)&msg_id, sizeof(msg_id));
+ if (val) {
+ ret = *val;
+ }
+ return ret;
+}
+
+struct session_flags *session_flags(struct session *session)
+{
+ return &session->sflags;
+}
+
+struct sockaddr *session_get_peer(struct session *session)
+{
+ return &session->peer.ip;
+}
+
+struct sockaddr *session_get_sockname(struct session *session)
+{
+ return &session->sockname.ip;
+}
+
+struct tls_ctx_t *session_tls_get_server_ctx(const struct session *session)
+{
+ return session->tls_ctx;
+}
+
+void session_tls_set_server_ctx(struct session *session, struct tls_ctx_t *ctx)
+{
+ session->tls_ctx = ctx;
+}
+
+struct tls_client_ctx_t *session_tls_get_client_ctx(const struct session *session)
+{
+ return session->tls_client_ctx;
+}
+
+void session_tls_set_client_ctx(struct session *session, struct tls_client_ctx_t *ctx)
+{
+ session->tls_client_ctx = ctx;
+}
+
+struct tls_common_ctx *session_tls_get_common_ctx(const struct session *session)
+{
+ struct tls_common_ctx *tls_ctx = session->sflags.outgoing ? &session->tls_client_ctx->c :
+ &session->tls_ctx->c;
+ return tls_ctx;
+}
+
+uv_handle_t *session_get_handle(struct session *session)
+{
+ return session->handle;
+}
+
+struct session *session_get(uv_handle_t *h)
+{
+ return h->data;
+}
+
+struct session *session_new(uv_handle_t *handle, bool has_tls)
+{
+ if (!handle) {
+ return NULL;
+ }
+ struct session *session = calloc(1, sizeof(struct session));
+ if (!session) {
+ return NULL;
+ }
+
+ queue_init(session->waiting);
+ session->tasks = trie_create(NULL);
+ if (handle->type == UV_TCP) {
+ size_t wire_buffer_size = KNOT_WIRE_MAX_PKTSIZE;
+ if (has_tls) {
+ /* When decoding large packets,
+ * gnutls gives the application chunks of size 16 kb each. */
+ wire_buffer_size += TLS_CHUNK_SIZE;
+ session->sflags.has_tls = true;
+ }
+ uint8_t *wire_buf = malloc(wire_buffer_size);
+ if (!wire_buf) {
+ free(session);
+ return NULL;
+ }
+ session->wire_buf = wire_buf;
+ session->wire_buf_size = wire_buffer_size;
+ } else if (handle->type == UV_UDP) {
+ /* We use the singleton buffer from the worker for all UDP (!).
+ * The libuv documentation doesn't really guarantee this is OK,
+ * but the implementation for unix systems does not hold
+ * the buffer (for either UDP or TCP): it always makes a non-blocking
+ * syscall that fills the buffer and immediately calls
+ * the callback, whatever the result of the operation.
+ * We still need to keep in mind to only touch the buffer
+ * in this callback... */
+ assert(handle->loop->data);
+ struct worker_ctx *worker = handle->loop->data;
+ session->wire_buf = worker->wire_buf;
+ session->wire_buf_size = sizeof(worker->wire_buf);
+ }
+
+ uv_timer_init(handle->loop, &session->timeout);
+
+ session->handle = handle;
+ handle->data = session;
+ session->timeout.data = session;
+ session_touch(session);
+
+ return session;
+}
+
+size_t session_tasklist_get_len(const struct session *session)
+{
+ return trie_weight(session->tasks);
+}
+
+size_t session_waitinglist_get_len(const struct session *session)
+{
+ return queue_len(session->waiting);
+}
+
+bool session_tasklist_is_empty(const struct session *session)
+{
+ return session_tasklist_get_len(session) == 0;
+}
+
+bool session_waitinglist_is_empty(const struct session *session)
+{
+ return session_waitinglist_get_len(session) == 0;
+}
+
+bool session_is_empty(const struct session *session)
+{
+ return session_tasklist_is_empty(session) &&
+ session_waitinglist_is_empty(session);
+}
+
+bool session_has_tls(const struct session *session)
+{
+ return session->sflags.has_tls;
+}
+
+void session_set_has_tls(struct session *session, bool has_tls)
+{
+ session->sflags.has_tls = has_tls;
+}
+
+void session_waitinglist_retry(struct session *session, bool increase_timeout_cnt)
+{
+ while (!session_waitinglist_is_empty(session)) {
+ struct qr_task *task = session_waitinglist_pop(session, false);
+ if (increase_timeout_cnt) {
+ worker_task_timeout_inc(task);
+ }
+ worker_task_step(task, &session->peer.ip, NULL);
+ worker_task_unref(task);
+ }
+}
+
+void session_waitinglist_finalize(struct session *session, int status)
+{
+ while (!session_waitinglist_is_empty(session)) {
+ struct qr_task *t = session_waitinglist_pop(session, false);
+ worker_task_finalize(t, status);
+ worker_task_unref(t);
+ }
+}
+
+void session_tasklist_finalize(struct session *session, int status)
+{
+ while (session_tasklist_get_len(session) > 0) {
+ struct qr_task *t = session_tasklist_del_first(session, false);
+ assert(worker_task_numrefs(t) > 0);
+ worker_task_finalize(t, status);
+ worker_task_unref(t);
+ }
+}
+
+int session_tasklist_finalize_expired(struct session *session)
+{
+ int ret = 0;
+ queue_t(struct qr_task *) q;
+ uint64_t now = kr_now();
+ trie_t *t = session->tasks;
+ trie_it_t *it;
+ queue_init(q);
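+ /* First pass: collect the tasks that exceeded KR_RESOLVE_TIME_LIMIT,
+ * taking an extra reference so they stay valid while being removed
+ * from the trie below. */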
+ for (it = trie_it_begin(t); !trie_it_finished(it); trie_it_next(it)) {
+ trie_val_t *v = trie_it_val(it);
+ struct qr_task *task = (struct qr_task *)*v;
+ if ((now - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) {
+ queue_push(q, task);
+ worker_task_ref(task);
+ }
+ }
+ trie_it_free(it);
+
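+ /* Second pass: remove the collected tasks from the trie and finalize
+ * those that have not finished yet. */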
+ struct qr_task *task = NULL;
+ uint16_t msg_id = 0;
+ char *key = (char *)&task;
+ int32_t keylen = sizeof(struct qr_task *);
+ if (session->sflags.outgoing) {
+ key = (char *)&msg_id;
+ keylen = sizeof(msg_id);
+ }
+ while (queue_len(q) > 0) {
+ task = queue_head(q);
+ if (session->sflags.outgoing) {
+ knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
+ msg_id = knot_wire_get_id(pktbuf->wire);
+ }
+ int res = trie_del(t, key, keylen, NULL);
+ if (!worker_task_finished(task)) {
+ /* task->pending_count must be zero,
+ * but there can be followers,
+ * so run worker_task_subreq_finalize() to ensure retrying
+ * for all the followers. */
+ worker_task_subreq_finalize(task);
+ worker_task_finalize(task, KR_STATE_FAIL);
+ }
+ if (res == KNOT_EOK) {
+ worker_task_unref(task);
+ }
+ queue_pop(q);
+ worker_task_unref(task);
+ ++ret;
+ }
+
+ queue_deinit(q);
+ return ret;
+}
+
+int session_timer_start(struct session *session, uv_timer_cb cb,
+ uint64_t timeout, uint64_t repeat)
+{
+ uv_timer_t *timer = &session->timeout;
+ assert(timer->data == session);
+ int ret = uv_timer_start(timer, cb, timeout, repeat);
+ if (ret != 0) {
+ uv_timer_stop(timer);
+ return kr_error(ENOMEM);
+ }
+ return 0;
+}
+
+int session_timer_restart(struct session *session)
+{
+ return uv_timer_again(&session->timeout);
+}
+
+int session_timer_stop(struct session *session)
+{
+ return uv_timer_stop(&session->timeout);
+}
+
+ssize_t session_wirebuf_consume(struct session *session, const uint8_t *data, ssize_t len)
+{
+ if (data != &session->wire_buf[session->wire_buf_end_idx]) {
+ /* shouldn't happen */
+ return kr_error(EINVAL);
+ }
+
+ if (len < 0) {
+ /* shouldn't happen */
+ return kr_error(EINVAL);
+ }
+
+ if (session->wire_buf_end_idx + len > session->wire_buf_size) {
+ /* shouldn't happen */
+ return kr_error(EINVAL);
+ }
+
+ session->wire_buf_end_idx += len;
+ return len;
+}
+
+knot_pkt_t *session_produce_packet(struct session *session, knot_mm_t *mm)
+{
+ session->sflags.wirebuf_error = false;
+ if (session->wire_buf_end_idx == 0) {
+ return NULL;
+ }
+
+ if (session->wire_buf_start_idx == session->wire_buf_end_idx) {
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+ return NULL;
+ }
+
+ if (session->wire_buf_start_idx > session->wire_buf_end_idx) {
+ session->sflags.wirebuf_error = true;
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+ return NULL;
+ }
+
+ const uv_handle_t *handle = session->handle;
+ uint8_t *msg_start = &session->wire_buf[session->wire_buf_start_idx];
+ ssize_t wirebuf_msg_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
+ uint16_t msg_size = 0;
+
+ if (!handle) {
+ session->sflags.wirebuf_error = true;
+ return NULL;
+ } else if (handle->type == UV_TCP) {
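+ /* DNS over TCP prefixes every message with a 2-byte length field
+ * (RFC 1035, section 4.2.2), so at least two bytes are needed
+ * before the message size is known. */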
+ if (wirebuf_msg_data_size < 2) {
+ return NULL;
+ }
+ msg_size = knot_wire_read_u16(msg_start);
+ if (msg_size >= session->wire_buf_size) {
+ session->sflags.wirebuf_error = true;
+ return NULL;
+ }
+ if (msg_size + 2 > wirebuf_msg_data_size) {
+ return NULL;
+ }
+ if (msg_size == 0) {
+ session->sflags.wirebuf_error = true;
+ return NULL;
+ }
+ msg_start += 2;
+ } else if (wirebuf_msg_data_size < UINT16_MAX) {
+ msg_size = wirebuf_msg_data_size;
+ } else {
+ session->sflags.wirebuf_error = true;
+ return NULL;
+ }
+
+
+ knot_pkt_t *pkt = knot_pkt_new(msg_start, msg_size, mm);
+ session->sflags.wirebuf_error = (pkt == NULL);
+ return pkt;
+}
+
+int session_discard_packet(struct session *session, const knot_pkt_t *pkt)
+{
+ uv_handle_t *handle = session->handle;
+ /* Pointer to data start in wire_buf */
+ uint8_t *wirebuf_data_start = &session->wire_buf[session->wire_buf_start_idx];
+ /* Number of data bytes in wire_buf */
+ size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
+ /* Pointer to message start in wire_buf */
+ uint8_t *wirebuf_msg_start = wirebuf_data_start;
+ /* Number of message bytes in wire_buf.
+ * For UDP it is the same number as wirebuf_data_size. */
+ size_t wirebuf_msg_size = wirebuf_data_size;
+ /* Wire data from parsed packet. */
+ uint8_t *pkt_msg_start = pkt->wire;
+ /* Number of bytes in packet wire buffer. */
+ size_t pkt_msg_size = pkt->size;
+ if (knot_pkt_has_tsig(pkt)) {
+ pkt_msg_size += pkt->tsig_wire.len;
+ }
+
+ session->sflags.wirebuf_error = true;
+
+ if (!handle) {
+ return kr_error(EINVAL);
+ } else if (handle->type == UV_TCP) {
+ /* wire_buf contains TCP DNS message. */
+ if (wirebuf_data_size < 2) {
+ /* TCP message length field isn't in buffer, must not happen. */
+ assert(0);
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+ return kr_error(EINVAL);
+ }
+ wirebuf_msg_size = knot_wire_read_u16(wirebuf_msg_start);
+ wirebuf_msg_start += 2;
+ if (wirebuf_msg_size + 2 > wirebuf_data_size) {
+ /* TCP message length field is greater than the
+ * number of bytes in the buffer, must not happen. */
+ assert(0);
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+ return kr_error(EINVAL);
+ }
+ }
+
+ if (wirebuf_msg_start != pkt_msg_start) {
+ /* The packet's wire data must be located at the beginning
+ * of the session wirebuf, anything else must not happen. */
+ assert(0);
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+ return kr_error(EINVAL);
+ }
+
+ if (wirebuf_msg_size < pkt_msg_size) {
+ /* Message length field is less than the packet size,
+ * must not happen. */
+ assert(0);
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+ return kr_error(EINVAL);
+ }
+
+ if (handle->type == UV_TCP) {
+ session->wire_buf_start_idx += wirebuf_msg_size + 2;
+ } else {
+ session->wire_buf_start_idx += pkt_msg_size;
+ }
+ session->sflags.wirebuf_error = false;
+
+ wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
+ if (wirebuf_data_size == 0) {
+ session_wirebuf_discard(session);
+ } else if (wirebuf_data_size < KNOT_WIRE_HEADER_SIZE) {
+ session_wirebuf_compress(session);
+ }
+
+ return kr_ok();
+}
+
+void session_wirebuf_discard(struct session *session)
+{
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = 0;
+}
+
+void session_wirebuf_compress(struct session *session)
+{
+ if (session->wire_buf_start_idx == 0) {
+ return;
+ }
+ uint8_t *wirebuf_data_start = &session->wire_buf[session->wire_buf_start_idx];
+ size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
+ if (session->wire_buf_start_idx < wirebuf_data_size) {
+ memmove(session->wire_buf, wirebuf_data_start, wirebuf_data_size);
+ } else {
+ memcpy(session->wire_buf, wirebuf_data_start, wirebuf_data_size);
+ }
+ session->wire_buf_start_idx = 0;
+ session->wire_buf_end_idx = wirebuf_data_size;
+}
+
+bool session_wirebuf_error(struct session *session)
+{
+ return session->sflags.wirebuf_error;
+}
+
+uint8_t *session_wirebuf_get_start(struct session *session)
+{
+ return session->wire_buf;
+}
+
+size_t session_wirebuf_get_size(struct session *session)
+{
+ return session->wire_buf_size;
+}
+
+uint8_t *session_wirebuf_get_free_start(struct session *session)
+{
+ return &session->wire_buf[session->wire_buf_end_idx];
+}
+
+size_t session_wirebuf_get_free_size(struct session *session)
+{
+ return session->wire_buf_size - session->wire_buf_end_idx;
+}
+
+void session_poison(struct session *session)
+{
+ kr_asan_poison(session, sizeof(*session));
+}
+
+void session_unpoison(struct session *session)
+{
+ kr_asan_unpoison(session, sizeof(*session));
+}
+
+int session_wirebuf_process(struct session *session, const struct sockaddr *peer)
+{
+ int ret = 0;
+ if (session->wire_buf_start_idx == session->wire_buf_end_idx) {
+ return ret;
+ }
+ struct worker_ctx *worker = session_get_handle(session)->loop->data;
+ size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
+ uint32_t max_iterations = (wirebuf_data_size / (KNOT_WIRE_HEADER_SIZE + KNOT_WIRE_QUESTION_MIN_SIZE)) + 1;
+ knot_pkt_t *query = NULL;
+ while (((query = session_produce_packet(session, &worker->pkt_pool)) != NULL) &&
+ (ret < max_iterations)) {
+ assert (!session_wirebuf_error(session));
+ int res = worker_submit(session, peer, query);
+ if (res != kr_error(EILSEQ)) {
+ /* Packet has been successfully parsed. */
+ ret += 1;
+ }
+ if (session_discard_packet(session, query) < 0) {
+ /* Packet data isn't stored in memory as expected.
+ Something went wrong; normally this should not happen. */
+ break;
+ }
+ }
+ if (session_wirebuf_error(session)) {
+ ret = -1;
+ }
+ return ret;
+}
+
+void session_kill_ioreq(struct session *s, struct qr_task *task)
+{
+ if (!s) {
+ return;
+ }
+ assert(s->sflags.outgoing && s->handle);
+ if (s->sflags.closing) {
+ return;
+ }
+ session_tasklist_del(s, task);
+ if (s->handle->type == UV_UDP) {
+ assert(session_tasklist_is_empty(s));
+ session_close(s);
+ return;
+ }
+}
+
+/** Update timestamp */
+void session_touch(struct session *s)
+{
+ s->last_activity = kr_now();
+}
+
+uint64_t session_last_activity(struct session *s)
+{
+ return s->last_activity;
+}
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <libknot/packet/pkt.h>
+
+#include <stdbool.h>
+#include <uv.h>
+
+struct qr_task;
+struct worker_ctx;
+struct session;
+
+struct session_flags {
+ bool outgoing : 1; /**< True: to upstream; false: from a client. */
+ bool throttled : 1; /**< True: data reading from peer is temporarily stopped. */
+ bool has_tls : 1; /**< True: given session uses TLS. */
+ bool connected : 1; /**< True: TCP connection is established. */
+ bool closing : 1; /**< True: session close sequence is in progress. */
+ bool wirebuf_error : 1; /**< True: last operation with wirebuf ended up with an error. */
+};
+
+/** Allocate a new session for a libuv handle.
+ * If handle->type is UV_UDP, the has_tls parameter is ignored. */
+struct session *session_new(uv_handle_t *handle, bool has_tls);
+/** Clear and free the given session. */
+void session_free(struct session *session);
+/** Clear the session. */
+void session_clear(struct session *session);
+/** Close session. */
+void session_close(struct session *session);
+/** Start reading from underlying libuv IO handle. */
+int session_start_read(struct session *session);
+/** Stop reading from underlying libuv IO handle. */
+int session_stop_read(struct session *session);
+
+/** List of tasks waiting for IO. */
+/** Check if list is empty. */
+bool session_waitinglist_is_empty(const struct session *session);
+/** Add task to the end of the list. */
+int session_waitinglist_push(struct session *session, struct qr_task *task);
+/** Get the first element. */
+struct qr_task *session_waitinglist_get(const struct session *session);
+/** Get the first element and remove it from the list. */
+struct qr_task *session_waitinglist_pop(struct session *session, bool deref);
+/** Get the list length. */
+size_t session_waitinglist_get_len(const struct session *session);
+/** Retry resolution for each task in the list. */
+void session_waitinglist_retry(struct session *session, bool increase_timeout_cnt);
+/** Finalize all tasks in the list. */
+void session_waitinglist_finalize(struct session *session, int status);
+
+/** List of tasks associated with session. */
+/** Check if list is empty. */
+bool session_tasklist_is_empty(const struct session *session);
+/** Get the first element. */
+struct qr_task *session_tasklist_get_first(struct session *session);
+/** Get the first element and remove it from the list. */
+struct qr_task *session_tasklist_del_first(struct session *session, bool deref);
+/** Get the list length. */
+size_t session_tasklist_get_len(const struct session *session);
+/** Add task to the list. */
+int session_tasklist_add(struct session *session, struct qr_task *task);
+/** Remove task from the list. */
+int session_tasklist_del(struct session *session, struct qr_task *task);
+/** Remove task with given msg_id, session_flags(session)->outgoing must be true. */
+struct qr_task* session_tasklist_del_msgid(const struct session *session, uint16_t msg_id);
+/** Find task with given msg_id */
+struct qr_task* session_tasklist_find_msgid(const struct session *session, uint16_t msg_id);
+/** Finalize all tasks in the list. */
+void session_tasklist_finalize(struct session *session, int status);
+/** Finalize all expired tasks in the list. */
+int session_tasklist_finalize_expired(struct session *session);
+
+/** Both of task lists (associated & waiting). */
+/** Check if empty. */
+bool session_is_empty(const struct session *session);
+/** Get pointer to session flags */
+struct session_flags *session_flags(struct session *session);
+/** Get pointer to peer address. */
+struct sockaddr *session_get_peer(struct session *session);
+/** Get pointer to sockname (address of our end, not meaningful for UDP downstream). */
+struct sockaddr *session_get_sockname(struct session *session);
+/** Get pointer to server-side tls-related data. */
+struct tls_ctx_t *session_tls_get_server_ctx(const struct session *session);
+/** Set pointer to server-side tls-related data. */
+void session_tls_set_server_ctx(struct session *session, struct tls_ctx_t *ctx);
+/** Get pointer to client-side tls-related data. */
+struct tls_client_ctx_t *session_tls_get_client_ctx(const struct session *session);
+/** Set pointer to client-side tls-related data. */
+void session_tls_set_client_ctx(struct session *session, struct tls_client_ctx_t *ctx);
+/** Get pointer to that part of tls-related data which has common structure for
+ * server and client. */
+struct tls_common_ctx *session_tls_get_common_ctx(const struct session *session);
+
+/** Get pointer to underlying libuv handle for IO operations. */
+uv_handle_t *session_get_handle(struct session *session);
+struct session *session_get(uv_handle_t *h);
+
+/** Start session timer. */
+int session_timer_start(struct session *session, uv_timer_cb cb,
+ uint64_t timeout, uint64_t repeat);
+/** Restart session timer without changing its parameters. */
+int session_timer_restart(struct session *session);
+/** Stop session timer. */
+int session_timer_stop(struct session *session);
+
+/** Get pointer to the beginning of session wirebuffer. */
+uint8_t *session_wirebuf_get_start(struct session *session);
+/** Get size of session wirebuffer. */
+size_t session_wirebuf_get_size(struct session *session);
+/** Get pointer to the beginning of free space in session wirebuffer. */
+uint8_t *session_wirebuf_get_free_start(struct session *session);
+/** Get amount of free space in session wirebuffer. */
+size_t session_wirebuf_get_free_size(struct session *session);
+/** Discard all data in session wirebuffer. */
+void session_wirebuf_discard(struct session *session);
+/** Move all data to the beginning of the buffer. */
+void session_wirebuf_compress(struct session *session);
+int session_wirebuf_process(struct session *session, const struct sockaddr *peer);
+ssize_t session_wirebuf_consume(struct session *session,
+ const uint8_t *data, ssize_t len);
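+
+/* Illustrative sketch of the intended read path (a hypothetical caller, not a
+ * definitive contract; the local variable names are made up):
+ *
+ *   uint8_t *dst = session_wirebuf_get_free_start(s);
+ *   size_t avail = session_wirebuf_get_free_size(s);
+ *   // ... read up to 'avail' bytes from the transport into 'dst' ...
+ *   session_wirebuf_consume(s, dst, nread);  // account for the received bytes
+ *   session_wirebuf_process(s, peer);        // parse and submit DNS messages
+ */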
+
+/** Poison session structure with ASAN. */
+void session_poison(struct session *session);
+/** Unpoison session structure with ASAN. */
+void session_unpoison(struct session *session);
+
+knot_pkt_t *session_produce_packet(struct session *session, knot_mm_t *mm);
+int session_discard_packet(struct session *session, const knot_pkt_t *pkt);
+
+void session_kill_ioreq(struct session *s, struct qr_task *task);
+/** Update timestamp */
+void session_touch(struct session *s);
+/** Returns the time of the last IO activity, or the creation time if no IO
+ * has happened yet. Used for TCP timeout calculation. */
+uint64_t session_last_activity(struct session *s);
--- /dev/null
+#include <lua.h>
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#include "kresconfig.h"
+
+#include "sr_subscriptions.h"
+#include "dbus_control.h"
+#include "watcher.h"
+#include "worker.h"
+
+#include "modules/sysrepo/common/sysrepo.h"
+#include "modules/sysrepo/common/string_helper.h"
+
+#define XPATH_SERVER XPATH_BASE"/server"
+#define XPATH_INSTANCES XPATH_BASE"/"YM_KRES":instances"
+#define XPATH_TST_SECRET XPATH_BASE"/network/tls/"YM_KRES":sticket-secret"
+#define XPATH_CACHE_PREFILL XPATH_BASE"/cache/"YM_KRES":prefill"
+#define XPATH_STATUS XPATH_SERVER"/"YM_KRES":status"
+#define XPATH_VERSION XPATH_SERVER"/package-version"
+
+
+static int kresd_instances_start(const char *method)
+{
+ int ret = SR_ERR_OK;
+ sr_val_t *vals = NULL;
+ size_t i, val_count = 0;
+ sysrepo_uv_ctx_t *sysrepo = the_worker->engine->watcher.sysrepo;
+ int kresd_instances = the_worker->engine->watcher.config.kresd_instances;
+
+ ret = sr_get_items(sysrepo->session, XPATH_BASE"/"YM_KRES":instances//name", 0, 0, &vals, &val_count);
+
+ for (i = 0; i < kresd_instances; ++i) {
+
+ char inst_name[128];
+ if (i < val_count)
+ sprintf(inst_name, "%s", vals[i].data.string_val);
+ else
+ sprintf(inst_name, "%zu", i);
+
+ kresd_ctl(method, inst_name);
+ }
+ sr_free_values(vals, val_count);
+ return ret;
+}
+
+static int kresd_instances_status(struct lyd_node **parent)
+{
+ char xpath[128];
+ int ret = SR_ERR_OK;
+ sr_val_t *vals = NULL;
+ size_t i, val_count = 0;
+ sysrepo_uv_ctx_t *sysrepo = the_worker->engine->watcher.sysrepo;
+ int kresd_instances = the_worker->engine->watcher.config.kresd_instances;
+
+ ret = sr_get_items(sysrepo->session, XPATH_BASE"/"YM_KRES":instances//name", 0, 0, &vals, &val_count);
+
+ for (i = 0; i < kresd_instances; ++i) {
+
+ char inst_name[128];
+ char *status;
+
+ if (i < val_count)
+ sprintf(inst_name, "%s", vals[i].data.string_val);
+ else
+ sprintf(inst_name, "%zu", i);
+
+ sprintf(xpath, XPATH_STATUS"/kresd-instances[name='%s']", inst_name);
+ kresd_get_status(inst_name, &status);
+ lyd_new_path(*parent, NULL, xpath, inst_name, 0, 0);
+
+ sprintf(xpath, XPATH_STATUS"/kresd-instances[name='%s']/status", inst_name);
+ lyd_new_path(*parent, NULL, xpath, status, 0, 0);
+
+ free(status);
+ }
+ sr_free_values(vals, val_count);
+ return ret;
+}
+
+int resolver_start()
+{
+ int ret = SR_ERR_OK;
+ struct server_config cfg = the_worker->engine->watcher.config;
+
+ kresd_instances_start(UNIT_START);
+
+ if (cfg.auto_cache_gc)
+ cache_gc_ctl(UNIT_START);
+
+ return ret;
+}
+
+int set_tst_secret(const char *new_secret)
+{
+ int ret = 0;
+ sr_conn_ctx_t *connection = NULL;
+ sr_session_ctx_t *session = NULL;
+
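+ /* One-shot sysrepo write: connect, open a session on the running datastore,
+ * set the leaf, validate and apply. Each step runs only if the previous one
+ * succeeded (ret stays 0), and any failure is reported below. */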
+ if (!ret) ret = sr_connect(0, &connection);
+ if (!ret) ret = sr_session_start(connection, SR_DS_RUNNING, &session);
+ if (!ret) ret = sr_set_item_str(session, XPATH_TST_SECRET, new_secret, NULL, 0);
+ if (!ret) ret = sr_validate(session, YM_COMMON, 0);
+ if (!ret) ret = sr_apply_changes(session, 0, 0);
+ if (ret)
+ kr_log_error(
+ "[sysrepo] failed to set '%s', %s\n",
+ XPATH_TST_SECRET, sr_strerror(ret));
+
+ sr_disconnect(connection);
+
+ return ret;
+}
+
+/* STATE DATA CALLBACKS */
+
+static int server_status_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath,
+const char *request_xpath, uint32_t request_id, struct lyd_node **parent, void *private_data)
+{
+ char *cache_gc_status;
+
+ /* kresd instances status */
+ kresd_instances_status(parent);
+
+ /* Cache Garbage Collector status */
+ cache_gc_get_status(&cache_gc_status);
+ lyd_new_path(*parent, NULL, XPATH_STATUS"/cache-gc", cache_gc_status, 0, 0);
+
+ free(cache_gc_status);
+ return SR_ERR_OK;
+}
+
+static int server_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath,
+const char *request_xpath, uint32_t request_id, struct lyd_node **parent, void *private_data)
+{
+ assert(parent!=NULL);
+
+ char str[128];
+ struct server_config cfg = the_worker->engine->watcher.config;
+ lyd_new_path(*parent, NULL, XPATH_VERSION, PACKAGE_VERSION, 0, 0);
+
+ sprintf(str, "%s", cfg.auto_start ? "true" : "false");
+ lyd_new_path(*parent, NULL, XPATH_SERVER"/"YM_KRES":auto-start", str, 0, 0);
+
+ sprintf(str, "%s", cfg.auto_cache_gc ? "true" : "false");
+ lyd_new_path(*parent, NULL, XPATH_SERVER"/"YM_KRES":auto-cache-gc", str, 0, 0);
+
+ sprintf(str, "%d", cfg.kresd_instances);
+ lyd_new_path(*parent, NULL, XPATH_SERVER"/"YM_KRES":kresd-instances", str, 0, 0);
+
+ return SR_ERR_OK;
+}
+
+static int instance_status_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath,
+const char *request_xpath, uint32_t request_id, struct lyd_node **parent, void *private_data)
+{
+ return SR_ERR_OK;
+}
+
+/* CONFIG DATA CALLBACKS */
+
+static int server_change_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath,
+sr_event_t event, uint32_t request_id, void *private_data)
+{
+ if(event == SR_EV_CHANGE)
+ {
+ /* validation actions*/
+ }
+ else if (event == SR_EV_DONE)
+ {
+ int err = SR_ERR_OK;
+ sr_change_oper_t oper;
+ sr_val_t *old_value = NULL;
+ sr_val_t *new_value = NULL;
+ sr_change_iter_t *it = NULL;
+ struct server_config *cfg = &the_worker->engine->watcher.config;
+
+ err = sr_get_changes_iter(session, XPATH_SERVER "/*/.", &it);
+ if (err != SR_ERR_OK) goto cleanup;
+
+ while ((sr_get_change_next(session, it, &oper, &old_value, &new_value)) == SR_ERR_OK) {
+
+ const char *leaf = remove_substr(new_value->xpath, XPATH_SERVER"/cznic-resolver-knot:");
+ printf("%s\n", new_value->xpath);
+ if (!strcmp(leaf, "auto-start"))
+ cfg->auto_start = new_value->data.bool_val;
+ else if (!strcmp(leaf, "auto-cache-gc"))
+ cfg->auto_cache_gc = new_value->data.bool_val;
+ else if (!strcmp(leaf, "kresd-instances"))
+ cfg->kresd_instances = new_value->data.uint8_val;
+
+ sr_free_val(old_value);
+ sr_free_val(new_value);
+ }
+
+ cleanup:
+ sr_free_change_iter(it);
+ }
+ else if(event == SR_EV_ABORT)
+ {
+ /* abort actions */
+ }
+
+ return SR_ERR_OK;
+}
+
+static int tls_sticket_secret_change_cb(sr_session_ctx_t *session,
+const char *module_name, const char *xpath, sr_event_t event,
+uint32_t request_id, void *private_data)
+{
+ if(event == SR_EV_CHANGE)
+ {
+ /* validation actions*/
+ }
+ else if (event == SR_EV_DONE)
+ {
+ int ret = 0;
+ uv_loop_t *loop = the_worker->loop;
+ ret = tst_secret_timer_init(loop);
+ if (ret){
+ kr_log_error("[sysrepo] failed to init tls session ticket secret");
+ }
+ }
+ else if(event == SR_EV_ABORT)
+ {
+ /* abort actions */
+ }
+
+ return SR_ERR_OK;
+}
+
+static int cache_prefill_change_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath,
+sr_event_t event, uint32_t request_id, void *private_data)
+{
+ if(event == SR_EV_CHANGE)
+ {
+ /* validation actions*/
+ }
+ else if (event == SR_EV_DONE)
+ {
+ int err = SR_ERR_OK;
+ sr_change_oper_t oper;
+ sr_val_t *old_value = NULL;
+ sr_val_t *new_value = NULL;
+ sr_change_iter_t *it = NULL;
+
+ struct server_config *cfg = &the_worker->engine->watcher.config;
+
+ err = sr_get_changes_iter(session, XPATH_CACHE_PREFILL "/*", &it);
+ if (err != SR_ERR_OK) goto cleanup;
+
+ lua_State *L = the_worker->engine->L;
+ engine_cmd(L, "modules.load('prefill')",false);
+ lua_getglobal(L, "prefill.config");
+ lua_newtable(L);
+
+ while ((sr_get_change_next(session, it, &oper, &old_value, &new_value)) == SR_ERR_OK) {
+
+ printf("%s\n", new_value->xpath);
+ char *leaf = strrchr(new_value->xpath, '/');
+
+ if (leaf && !strcmp(leaf, "/origin")) {
+
+ }
+ if (leaf && !strcmp(leaf, "/url")) {
+ lua_pushstring(L, new_value->data.string_val);
+ lua_setfield(L, -2, "url");
+ }
+
+ if (leaf && !strcmp(leaf, "/ca-file")){
+ lua_pushstring(L, new_value->data.string_val);
+ lua_setfield(L, -2, "ca_file");
+ }
+ if (leaf && !strcmp(leaf, "/refresh-interval")) {
+ lua_pushnumber(L, new_value->data.uint32_val);
+ lua_setfield(L, -2, "interval");
+ }
+
+ sr_free_val(old_value);
+ sr_free_val(new_value);
+ }
+
+ lua_setglobal(L, ".");
+ engine_pcall(L, 1);
+
+ cleanup:
+ sr_free_change_iter(it);
+ }
+ else if(event == SR_EV_ABORT)
+ {
+ /* abort actions */
+ }
+ return SR_ERR_OK;
+}
+
+/* RPC CALLBACKS */
+
+/* Callback for kresd instances control */
+static int rpc_instance_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ const char *leaf = remove_substr(path, XPATH_SERVER"/");
+
+
+ return 0;
+}
+
+static int rpc_resolver_start_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ int ret = SR_ERR_OK;
+ struct server_config cfg = the_worker->engine->watcher.config;
+
+ kresd_instances_start(UNIT_START);
+
+ if (cfg.auto_cache_gc)
+ cache_gc_ctl(UNIT_START);
+
+ return ret;
+}
+
+static int rpc_resolver_stop_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ int ret = SR_ERR_OK;
+ struct server_config cfg = the_worker->engine->watcher.config;
+
+ kresd_instances_start(UNIT_STOP);
+
+ cache_gc_ctl(UNIT_STOP);
+
+ return ret;
+}
+
+static int rpc_resolver_restart_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ int ret = SR_ERR_OK;
+ struct server_config cfg = the_worker->engine->watcher.config;
+
+ kresd_instances_start(UNIT_RESTART);
+
+ if (cfg.auto_cache_gc)
+ cache_gc_ctl(UNIT_RESTART);
+
+ return ret;
+}
+
+static int rpc_cache_gc_start_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ int ret = cache_gc_ctl(UNIT_START);
+
+ return 0;
+}
+
+static int rpc_cache_gc_stop_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ int ret = cache_gc_ctl(UNIT_STOP);
+
+ return 0;
+}
+
+static int rpc_cache_gc_restart_cb(sr_session_ctx_t *session, const char *path,
+const sr_val_t *input, const size_t input_cnt, sr_event_t event,
+uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data)
+{
+ int ret = cache_gc_ctl(UNIT_RESTART);
+
+ return 0;
+}
+
+int sysrepo_subscr_register(sr_session_ctx_t *session, sr_subscription_ctx_t **subscription)
+{
+ int err = SR_ERR_OK;
+
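+ /* All subscriptions below use SR_SUBSCR_NO_THREAD, i.e. sysrepo does not
+ * spawn its own handler thread; the events are expected to be processed
+ * from the application's event loop via sr_process_events(). */
+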
+ /* CONFIG CHANGES */
+
+ err = sr_module_change_subscribe(session, YM_COMMON, XPATH_SERVER,
+ server_change_cb, NULL, 0, SR_SUBSCR_NO_THREAD|SR_SUBSCR_ENABLED|SR_SUBSCR_DONE_ONLY, subscription);
+ if (err != SR_ERR_OK) return err;
+
+ err = sr_module_change_subscribe(session, YM_COMMON, XPATH_TST_SECRET,
+ tls_sticket_secret_change_cb, NULL, 0, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE|SR_SUBSCR_DONE_ONLY, subscription);
+ if (err != SR_ERR_OK) return err;
+
+ err = sr_module_change_subscribe(session, YM_COMMON, XPATH_CACHE_PREFILL,
+ cache_prefill_change_cb, NULL, 0, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE|SR_SUBSCR_ENABLED|SR_SUBSCR_DONE_ONLY, subscription);
+ if (err != SR_ERR_OK) return err;
+
+ /* RPC OPERATIONS */
+
+ err = sr_rpc_subscribe(session, XPATH_RPC_BASE":start", rpc_resolver_start_cb, NULL, 0,
+ SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != 0) return err;
+
+ err = sr_rpc_subscribe(session, XPATH_RPC_BASE":stop", rpc_resolver_stop_cb, NULL, 0,
+ SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != 0) return err;
+
+ err = sr_rpc_subscribe(session, XPATH_RPC_BASE":restart", rpc_resolver_restart_cb, NULL, 0,
+ SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != 0) return err;
+
+ err = sr_rpc_subscribe(session, XPATH_GC"/start", rpc_cache_gc_start_cb, NULL, 0,
+ SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != 0) return err;
+
+ err = sr_rpc_subscribe(session, XPATH_GC"/stop", rpc_cache_gc_stop_cb, NULL, 0,
+ SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != 0) return err;
+
+ err = sr_rpc_subscribe(session, XPATH_GC"/restart", rpc_cache_gc_restart_cb, NULL, 0,
+ SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != 0) return err;
+
+ /* STATE DATA */
+
+ err = sr_oper_get_items_subscribe(session, YM_COMMON, XPATH_SERVER, server_cb, NULL, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != SR_ERR_OK) return err;
+
+ err = sr_oper_get_items_subscribe(session, YM_COMMON, XPATH_STATUS, server_status_cb, NULL, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != SR_ERR_OK) return err;
+
+ err = sr_oper_get_items_subscribe(session, YM_COMMON, XPATH_INSTANCES, instance_status_cb, NULL, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription);
+ if (err != SR_ERR_OK) return err;
+
+ return err;
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+int sysrepo_subscr_register(sr_session_ctx_t *session, sr_subscription_ctx_t **subscription);
+
+int set_tst_secret(const char *secret);
+
+int resolver_start();
--- /dev/null
+/*
+ * Copyright (C) 2016 American Civil Liberties Union (ACLU)
+ * 2016-2018 CZ.NIC, z.s.p.o
+ *
+ * Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
+ * Ondřej Surý <ondrej@sury.org>
+ *
+ * This program is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <gnutls/abstract.h>
+#include <gnutls/crypto.h>
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
+#include <uv.h>
+
+#include <assert.h>
+#include <errno.h>
+#include <stdlib.h>
+
+#include "contrib/ucw/lib.h"
+#include "contrib/base64.h"
+#include "io.h"
+#include "tls.h"
+#include "worker.h"
+#include "session.h"
+
+#define EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE (60*60*24*7)
+#define GNUTLS_PIN_MIN_VERSION 0x030400
+
+/** @internal Debugging facility. */
+#ifdef DEBUG
+#define DEBUG_MSG(...) kr_log_verbose("[tls] " __VA_ARGS__)
+#else
+#define DEBUG_MSG(...)
+#endif
+
+struct async_write_ctx {
+ uv_write_t write_req;
+ struct tls_common_ctx *t;
+ char buf[];
+};
+
+static char const server_logstring[] = "tls";
+static char const client_logstring[] = "tls_client";
+
+static int client_verify_certificate(gnutls_session_t tls_session);
+
+/**
+ * Set mandatory security settings from
+ * https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles-11#section-9
+ * Performance optimizations are not implemented at the moment.
+ */
+static int kres_gnutls_set_priority(gnutls_session_t session) {
+ static const char * const priorities =
+ "NORMAL:" /* GnuTLS defaults */
+ "-VERS-TLS1.0:-VERS-TLS1.1:" /* TLS 1.2 and higher */
+ /* Some distros by default allow features that are considered
+ * too insecure nowadays, so let's disable them explicitly. */
+ "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
+ const char *errpos = NULL;
+ int err = gnutls_priority_set_direct(session, priorities, &errpos);
+ if (err != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n",
+ priorities, errpos - priorities, errpos, gnutls_strerror_name(err), err);
+ }
+ return err;
+}
+
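+/* GnuTLS pull callback: instead of reading from a socket, it hands GnuTLS the
+ * data that libuv has already delivered into the tls_common_ctx buffer and
+ * returns EAGAIN once that buffer has been fully consumed. */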
+static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
+{
+ struct tls_common_ctx *t = (struct tls_common_ctx *)h;
+ assert(t != NULL);
+
+ ssize_t avail = t->nread - t->consumed;
+ DEBUG_MSG("[%s] pull wanted: %zu available: %zu\n",
+ t->client_side ? "tls_client" : "tls", len, avail);
+ if (t->nread <= t->consumed) {
+ errno = EAGAIN;
+ return -1;
+ }
+
+ ssize_t transfer = MIN(avail, len);
+ memcpy(buf, t->buf + t->consumed, transfer);
+ t->consumed += transfer;
+ return transfer;
+}
+
+static void on_write_complete(uv_write_t *req, int status)
+{
+ assert(req->data != NULL);
+ struct async_write_ctx *async_ctx = (struct async_write_ctx *)req->data;
+ struct tls_common_ctx *t = async_ctx->t;
+ assert(t->write_queue_size);
+ t->write_queue_size -= 1;
+ free(req->data);
+}
+
+static bool stream_queue_is_empty(struct tls_common_ctx *t)
+{
+ return (t->write_queue_size == 0);
+}
+
+static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt)
+{
+ struct tls_common_ctx *t = (struct tls_common_ctx *)h;
+
+ if (t == NULL) {
+ errno = EFAULT;
+ return -1;
+ }
+
+ if (iovcnt == 0) {
+ return 0;
+ }
+
+ assert(t->session);
+ uv_stream_t *handle = (uv_stream_t *)session_get_handle(t->session);
+ assert(handle && handle->type == UV_TCP);
+
+ /*
+ * This is a little bit complicated. There are two different writes:
+ * 1. Immediate, these don't need to own the buffered data and return immediately
+ * 2. Asynchronous, these need to own the buffers until the write completes
+ * In order to avoid copying the buffer, an immediate write is tried first if possible.
+ * If it isn't possible to write the data without queueing, an asynchronous write
+ * is created (with copied buffered data).
+ */
+
+ size_t total_len = 0;
+ uv_buf_t uv_buf[iovcnt];
+ for (int i = 0; i < iovcnt; ++i) {
+ uv_buf[i].base = iov[i].iov_base;
+ uv_buf[i].len = iov[i].iov_len;
+ total_len += iov[i].iov_len;
+ }
+
+ /* Try to perform the immediate write first to avoid copy */
+ int ret = 0;
+ if (stream_queue_is_empty(t)) {
+ ret = uv_try_write(handle, uv_buf, iovcnt);
+ DEBUG_MSG("[%s] push %zu <%p> = %d\n",
+ t->client_side ? "tls_client" : "tls", total_len, h, ret);
+ /* from libuv documentation -
+ uv_try_write will return either:
+ > 0: number of bytes written (can be less than the supplied buffer size).
+ < 0: negative error code (UV_EAGAIN is returned if no data can be sent immediately).
+ */
+ if (ret == total_len) {
+ /* libuv wrote all the data immediately,
+ * nothing left to queue. Return. */
+ return ret;
+ }
+
+ if (ret < 0 && ret != UV_EAGAIN) {
+ /* uv_try_write() has returned an error code other than UV_EAGAIN.
+ * Return. */
+ ret = -1;
+ errno = EIO;
+ return ret;
+ }
+ /* Since we are here, the expression below holds:
+ * (ret != total_len) && (ret >= 0 || ret == UV_EAGAIN)
+ * or equivalently
+ * (ret != total_len && ret >= 0) || (ret != total_len && ret == UV_EAGAIN)
+ * i.e. either a partial write happened or UV_EAGAIN was returned.
+ * Proceed to copy the remaining data into owned memory and perform an async write.
+ */
+ if (ret == UV_EAGAIN) {
+ /* No data was written, so we must buffer all of it. */
+ ret = 0;
+ }
+ }
+
+ /* Fallback when the queue is full, and it's not possible to do an immediate write */
+ char *p = malloc(sizeof(struct async_write_ctx) + total_len - ret);
+ if (p != NULL) {
+ struct async_write_ctx *async_ctx = (struct async_write_ctx *)p;
+ /* Save pointer to session tls context */
+ async_ctx->t = t;
+ char *buf = async_ctx->buf;
+ /* Skip data written in the partial write */
+ size_t to_skip = ret;
+ /* Copy the buffer into owned memory */
+ size_t off = 0;
+ for (int i = 0; i < iovcnt; ++i) {
+ if (to_skip > 0) {
+ /* Ignore current buffer if it's all skipped */
+ if (to_skip >= uv_buf[i].len) {
+ to_skip -= uv_buf[i].len;
+ continue;
+ }
+ /* Skip only part of the buffer */
+ uv_buf[i].base += to_skip;
+ uv_buf[i].len -= to_skip;
+ to_skip = 0;
+ }
+ memcpy(buf + off, uv_buf[i].base, uv_buf[i].len);
+ off += uv_buf[i].len;
+ }
+ uv_buf[0].base = buf;
+ uv_buf[0].len = off;
+
+ /* Create an asynchronous write request */
+ uv_write_t *write_req = &async_ctx->write_req;
+ memset(write_req, 0, sizeof(uv_write_t));
+ write_req->data = p;
+
+ /* Perform an asynchronous write with a callback */
+ if (uv_write(write_req, handle, uv_buf, 1, on_write_complete) == 0) {
+ ret = total_len;
+ t->write_queue_size += 1;
+ } else {
+ free(p);
+ errno = EIO;
+ ret = -1;
+ }
+ } else {
+ errno = ENOMEM;
+ ret = -1;
+ }
+
+ DEBUG_MSG("[%s] queued %zu <%p> = %d\n",
+ t->client_side ? "tls_client" : "tls", total_len, h, ret);
+
+ return ret;
+}
+
+/** Perform the TLS handshake and handle error codes according to the documentation.
+ * See https://gnutls.org/manual/html_node/TLS-handshake.html#TLS-handshake
+ * The function returns kr_ok() on success or a non-fatal error, kr_error(EAGAIN) on blocking, or kr_error(EIO) on a fatal error.
+ */
+static int tls_handshake(struct tls_common_ctx *ctx, tls_handshake_cb handshake_cb) {
+ struct session *session = ctx->session;
+ const char *logstring = ctx->client_side ? client_logstring : server_logstring;
+
+ int err = gnutls_handshake(ctx->tls_session);
+ if (err == GNUTLS_E_SUCCESS) {
+ /* Handshake finished, return success */
+ ctx->handshake_state = TLS_HS_DONE;
+ struct sockaddr *peer = session_get_peer(session);
+ kr_log_verbose("[%s] TLS handshake with %s has completed\n",
+ logstring, kr_straddr(peer));
+ if (handshake_cb) {
+ if (handshake_cb(session, 0) != kr_ok()) {
+ return kr_error(EIO);
+ }
+ }
+ } else if (err == GNUTLS_E_AGAIN) {
+ return kr_error(EAGAIN);
+ } else if (gnutls_error_is_fatal(err)) {
+ /* Fatal errors, return error as it's not recoverable */
+ kr_log_verbose("[%s] gnutls_handshake failed: %s (%d)\n",
+ logstring,
+ gnutls_strerror_name(err), err);
+ if (handshake_cb) {
+ handshake_cb(session, -1);
+ }
+ return kr_error(EIO);
+ } else if (err == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+ /* Handle warning when in verbose mode */
+ const char *alert_name = gnutls_alert_get_name(gnutls_alert_get(ctx->tls_session));
+ if (alert_name != NULL) {
+ struct sockaddr *peer = session_get_peer(session);
+ kr_log_verbose("[%s] TLS alert from %s received: %s\n",
+ logstring, kr_straddr(peer), alert_name);
+ }
+ }
+ return kr_ok();
+}
+
+
+struct tls_ctx_t *tls_new(struct worker_ctx *worker)
+{
+ assert(worker != NULL);
+ assert(worker->engine != NULL);
+
+ struct network *net = &worker->engine->net;
+ if (!net->tls_credentials) {
+ net->tls_credentials = tls_get_ephemeral_credentials(worker->engine);
+ if (!net->tls_credentials) {
+ kr_log_error("[tls] X.509 credentials are missing, and ephemeral credentials failed; no TLS\n");
+ return NULL;
+ }
+ kr_log_info("[tls] Using ephemeral TLS credentials:\n");
+ tls_credentials_log_pins(net->tls_credentials);
+ }
+
+ time_t now = time(NULL);
+ if (net->tls_credentials->valid_until != GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION) {
+ if (net->tls_credentials->ephemeral_servicename) {
+ /* ephemeral cert: refresh if due to expire within a week */
+ if (now >= net->tls_credentials->valid_until - EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE) {
+ struct tls_credentials *newcreds = tls_get_ephemeral_credentials(worker->engine);
+ if (newcreds) {
+ tls_credentials_release(net->tls_credentials);
+ net->tls_credentials = newcreds;
+ kr_log_info("[tls] Renewed expiring ephemeral X.509 cert\n");
+ } else {
+ kr_log_error("[tls] Failed to renew expiring ephemeral X.509 cert, using existing one\n");
+ }
+ }
+ } else {
+ /* non-ephemeral cert: warn once when certificate expires */
+ if (now >= net->tls_credentials->valid_until) {
+ kr_log_error("[tls] X.509 certificate has expired!\n");
+ net->tls_credentials->valid_until = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
+ }
+ }
+ }
+
+ struct tls_ctx_t *tls = calloc(1, sizeof(struct tls_ctx_t));
+ if (tls == NULL) {
+ kr_log_error("[tls] failed to allocate TLS context\n");
+ return NULL;
+ }
+
+ int err = gnutls_init(&tls->c.tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
+ if (err != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] gnutls_init(): %s (%d)\n", gnutls_strerror_name(err), err);
+ tls_free(tls);
+ return NULL;
+ }
+ tls->credentials = tls_credentials_reserve(net->tls_credentials);
+ err = gnutls_credentials_set(tls->c.tls_session, GNUTLS_CRD_CERTIFICATE,
+ tls->credentials->credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] gnutls_credentials_set(): %s (%d)\n", gnutls_strerror_name(err), err);
+ tls_free(tls);
+ return NULL;
+ }
+ if (kres_gnutls_set_priority(tls->c.tls_session) != GNUTLS_E_SUCCESS) {
+ tls_free(tls);
+ return NULL;
+ }
+
+ tls->c.worker = worker;
+ tls->c.client_side = false;
+
+ gnutls_transport_set_pull_function(tls->c.tls_session, kres_gnutls_pull);
+ gnutls_transport_set_vec_push_function(tls->c.tls_session, kres_gnutls_vec_push);
+ gnutls_transport_set_ptr(tls->c.tls_session, tls);
+
+ if (net->tls_session_ticket_ctx) {
+ tls_session_ticket_enable(net->tls_session_ticket_ctx,
+ tls->c.tls_session);
+ }
+
+ return tls;
+}
+
+void tls_close(struct tls_common_ctx *ctx)
+{
+ if (ctx == NULL || ctx->tls_session == NULL) {
+ return;
+ }
+
+ assert(ctx->session);
+
+ if (ctx->handshake_state == TLS_HS_DONE) {
+ const struct sockaddr *peer = session_get_peer(ctx->session);
+ kr_log_verbose("[%s] closing tls connection to `%s`\n",
+ ctx->client_side ? "tls_client" : "tls",
+ kr_straddr(peer));
+ ctx->handshake_state = TLS_HS_CLOSING;
+ gnutls_bye(ctx->tls_session, GNUTLS_SHUT_RDWR);
+ }
+}
+
+void tls_free(struct tls_ctx_t *tls)
+{
+ if (!tls) {
+ return;
+ }
+
+ if (tls->c.tls_session) {
+ /* Don't terminate TLS connection, just tear it down */
+ gnutls_deinit(tls->c.tls_session);
+ tls->c.tls_session = NULL;
+ }
+
+ tls_credentials_release(tls->credentials);
+ free(tls);
+}
+
+int tls_write(uv_write_t *req, uv_handle_t *handle, knot_pkt_t *pkt, uv_write_cb cb)
+{
+ if (!pkt || !handle || !handle->data) {
+ return kr_error(EINVAL);
+ }
+
+ struct session *s = handle->data;
+ struct tls_common_ctx *tls_ctx = session_tls_get_common_ctx(s);
+
+ assert (tls_ctx);
+ assert (session_flags(s)->outgoing == tls_ctx->client_side);
+
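+ /* DNS over TLS frames each message with the same 2-byte length prefix
+ * as DNS over TCP (RFC 7858, section 3.3). */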
+ const uint16_t pkt_size = htons(pkt->size);
+ const char *logstring = tls_ctx->client_side ? client_logstring : server_logstring;
+ gnutls_session_t tls_session = tls_ctx->tls_session;
+
+ gnutls_record_cork(tls_session);
+ ssize_t count = 0;
+ if (((count = gnutls_record_send(tls_session, &pkt_size, sizeof(pkt_size))) < 0) ||
+ ((count = gnutls_record_send(tls_session, pkt->wire, pkt->size)) < 0)) {
+ kr_log_error("[%s] gnutls_record_send failed: %s (%zd)\n",
+ logstring, gnutls_strerror_name(count), count);
+ return kr_error(EIO);
+ }
+
+ const ssize_t submitted = sizeof(pkt_size) + pkt->size;
+
+ int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT);
+ if (ret < 0) {
+ if (!gnutls_error_is_fatal(ret)) {
+ return kr_error(EAGAIN);
+ } else {
+ kr_log_error("[%s] gnutls_record_uncork failed: %s (%d)\n",
+ logstring, gnutls_strerror_name(ret), ret);
+ return kr_error(EIO);
+ }
+ }
+
+ if (ret != submitted) {
+ kr_log_error("[%s] gnutls_record_uncork didn't send all data (%d of %zd)\n",
+ logstring, ret, submitted);
+ return kr_error(EIO);
+ }
+
+ /* The data is now accepted in gnutls internal buffers, the message can be treated as sent */
+ req->handle = (uv_stream_t *)handle;
+ cb(req, 0);
+
+ return kr_ok();
+}
+
+ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nread)
+{
+ struct tls_common_ctx *tls_p = session_tls_get_common_ctx(s);
+ if (!tls_p) {
+ return kr_error(ENOSYS);
+ }
+
+ assert(tls_p->session == s);
+ const bool ok = tls_p->recv_buf == buf && nread <= sizeof(tls_p->recv_buf);
+ if (!ok) {
+ assert(false);
+ /* don't risk overflowing the buffer if we have a mistake somewhere */
+ return kr_error(EINVAL);
+ }
+
+ const char *logstring = tls_p->client_side ? client_logstring : server_logstring;
+
+ tls_p->buf = buf;
+ tls_p->nread = nread >= 0 ? nread : 0;
+ tls_p->consumed = 0;
+
+ /* Ensure TLS handshake is performed before receiving data.
+ * See https://www.gnutls.org/manual/html_node/TLS-handshake.html */
+ while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) {
+ int err = tls_handshake(tls_p, tls_p->handshake_cb);
+ if (err == kr_error(EAGAIN)) {
+ return 0; /* Wait for more data */
+ } else if (err != kr_ok()) {
+ return err;
+ }
+ }
+
+ /* See https://gnutls.org/manual/html_node/Data-transfer-and-termination.html#Data-transfer-and-termination */
+ ssize_t submitted = 0;
+ uint8_t *wire_buf = session_wirebuf_get_free_start(s);
+ size_t wire_buf_size = session_wirebuf_get_free_size(s);
+ while (true) {
+ ssize_t count = gnutls_record_recv(tls_p->tls_session, wire_buf, wire_buf_size);
+ if (count == GNUTLS_E_AGAIN) {
+ if (tls_p->consumed == tls_p->nread) {
+ /* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
+ break; /* No more data available in this libuv buffer */
+ }
+ continue;
+ } else if (count == GNUTLS_E_INTERRUPTED) {
+ continue;
+ } else if (count == GNUTLS_E_REHANDSHAKE) {
+ /* See https://www.gnutls.org/manual/html_node/Re_002dauthentication.html */
+ struct sockaddr *peer = session_get_peer(s);
+ kr_log_verbose("[%s] TLS rehandshake with %s has started\n",
+ logstring, kr_straddr(peer));
+ tls_set_hs_state(tls_p, TLS_HS_IN_PROGRESS);
+ int err = kr_ok();
+ while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) {
+ err = tls_handshake(tls_p, tls_p->handshake_cb);
+ if (err == kr_error(EAGAIN)) {
+ break;
+ } else if (err != kr_ok()) {
+ return err;
+ }
+ }
+ if (err == kr_error(EAGAIN)) {
+ /* pull function is out of data */
+ break;
+ }
+ /* There may still be data available, check it. */
+ continue;
+ } else if (count < 0) {
+ kr_log_verbose("[%s] gnutls_record_recv failed: %s (%zd)\n",
+ logstring, gnutls_strerror_name(count), count);
+ return kr_error(EIO);
+ } else if (count == 0) {
+ break;
+ }
+ DEBUG_MSG("[%s] received %zd data\n", logstring, count);
+ wire_buf += count;
+ wire_buf_size -= count;
+ submitted += count;
+ if (wire_buf_size == 0 && tls_p->consumed != tls_p->nread) {
+ /* Session buffer is full
+ * but not all the data has been consumed. */
+ return kr_error(ENOSPC);
+ }
+ }
+ /* Here all data must be consumed. */
+ if (tls_p->consumed != tls_p->nread) {
+ /* Something went wrong, better return an error.
+ * This is most probably because gnutls_record_recv() did not
+ * consume all available network data by calling kres_gnutls_pull().
+ * TODO: assess whether the remaining data should be buffered.
+ */
+ return kr_error(ENOSPC);
+ }
+ return submitted;
+}
+
+#if TLS_CAN_USE_PINS
+/*
+ DNS-over-TLS Out of band key-pinned authentication profile uses the
+ same form of pins as HPKP:
+
+ e.g. pin-sha256="FHkyLhvI0n70E47cJlRTamTrnYVcsYdjUGbr79CfAVI="
+
+ DNS-over-TLS OOB key-pins: https://tools.ietf.org/html/rfc7858#appendix-A
+ HPKP pin reference: https://tools.ietf.org/html/rfc7469#appendix-A
+*/
+#define PINLEN ((((32) * 8 + 4)/6) + 3 + 1)
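+/* Buffer size for the textual pin: base64 of a 32-byte SHA-256 digest takes
+ * 44 characters including '=' padding, plus a terminating NUL; the formula
+ * above is a slight over-estimate of that. */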
+
+/* Compute pin_sha256 for the certificate.
+ * It may be in raw format - just TLS_SHA256_RAW_LEN bytes without termination,
+ * or it may be a base64 0-terminated string requiring up to
+ * TLS_SHA256_BASE64_BUFLEN bytes.
+ * \return error code */
+static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len, bool raw)
+{
+ if (raw && outchar_len < TLS_SHA256_RAW_LEN) {
+ assert(false);
+ return kr_error(ENOSPC);
+ /* With !raw we have check inside base64_encode. */
+ }
+ gnutls_pubkey_t key;
+ int err = gnutls_pubkey_init(&key);
+ if (err != GNUTLS_E_SUCCESS) return err;
+
+ gnutls_datum_t datum = { .data = NULL, .size = 0 };
+ err = gnutls_pubkey_import_x509(key, crt, 0);
+ if (err != GNUTLS_E_SUCCESS) goto leave;
+
+ err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum);
+ if (err != GNUTLS_E_SUCCESS) goto leave;
+
+ char raw_pin[TLS_SHA256_RAW_LEN]; /* TMP buffer if raw == false */
+ err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size,
+ (raw ? outchar : raw_pin));
+ if (err != GNUTLS_E_SUCCESS || raw/*success*/)
+ goto leave;
+ /* Convert to non-raw. */
+ err = base64_encode((uint8_t *)raw_pin, sizeof(raw_pin),
+ (uint8_t *)outchar, outchar_len);
+ if (err >= 0 && err < outchar_len) {
+ outchar[err] = '\0'; /* base64_encode() doesn't do it */
+ err = GNUTLS_E_SUCCESS;
+ } else if (err >= 0) {
+ assert(false);
+ err = kr_error(ENOSPC); /* base64 fits but '\0' doesn't */
+ outchar[outchar_len - 1] = '\0';
+ }
+leave:
+ gnutls_free(datum.data);
+ gnutls_pubkey_deinit(key);
+ return err;
+}
+
+void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
+{
+ for (int index = 0;; index++) {
+ gnutls_x509_crt_t *certs = NULL;
+ unsigned int cert_count = 0;
+ int err = gnutls_certificate_get_x509_crt(tls_credentials->credentials,
+ index, &certs, &cert_count);
+ if (err != GNUTLS_E_SUCCESS) {
+ if (err != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
+ kr_log_error("[tls] could not get X.509 certificates (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ return;
+ }
+
+ for (int i = 0; i < cert_count; i++) {
+ char pin[TLS_SHA256_BASE64_BUFLEN] = { 0 };
+ err = get_oob_key_pin(certs[i], pin, sizeof(pin), false);
+ if (err != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n",
+ i, err, gnutls_strerror_name(err));
+ } else {
+ kr_log_info("[tls] RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n",
+ i, pin);
+ }
+ gnutls_x509_crt_deinit(certs[i]);
+ }
+ gnutls_free(certs);
+ }
+}
+#else
+void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
+{
+ kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin; GnuTLS 3.4.0+ required\n");
+}
+#endif
+
+static int str_replace(char **where_ptr, const char *with)
+{
+ char *copy = with ? strdup(with) : NULL;
+ if (with && !copy) {
+ return kr_error(ENOMEM);
+ }
+
+ free(*where_ptr);
+ *where_ptr = copy;
+ return kr_ok();
+}
+
+static time_t _get_end_entity_expiration(gnutls_certificate_credentials_t creds)
+{
+ gnutls_datum_t data;
+ gnutls_x509_crt_t cert = NULL;
+ int err;
+ time_t ret = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION;
+
+ if ((err = gnutls_certificate_get_crt_raw(creds, 0, 0, &data)) != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] failed to get cert to check expiration: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ goto done;
+ }
+ if ((err = gnutls_x509_crt_init(&cert)) != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] failed to initialize cert: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ goto done;
+ }
+ if ((err = gnutls_x509_crt_import(cert, &data, GNUTLS_X509_FMT_DER)) != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] failed to construct cert while checking expiration: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ goto done;
+ }
+
+ ret = gnutls_x509_crt_get_expiration_time (cert);
+ done:
+ /* do not free data; g_c_get_crt_raw() says to treat it as
+ * constant. */
+ gnutls_x509_crt_deinit(cert);
+ return ret;
+}
+
+int tls_certificate_set(struct network *net, const char *tls_cert, const char *tls_key)
+{
+ if (!net) {
+ return kr_error(EINVAL);
+ }
+
+ struct tls_credentials *tls_credentials = calloc(1, sizeof(*tls_credentials));
+ if (tls_credentials == NULL) {
+ return kr_error(ENOMEM);
+ }
+
+ int err = 0;
+ if ((err = gnutls_certificate_allocate_credentials(&tls_credentials->credentials)) != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls] gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ tls_credentials_free(tls_credentials);
+ return kr_error(ENOMEM);
+ }
+ if ((err = gnutls_certificate_set_x509_system_trust(tls_credentials->credentials)) < 0) {
+ if (err != GNUTLS_E_UNIMPLEMENTED_FEATURE) {
+ kr_log_error("[tls] warning: gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ tls_credentials_free(tls_credentials);
+ return err;
+ }
+ }
+
+ if ((str_replace(&tls_credentials->tls_cert, tls_cert) != 0) ||
+ (str_replace(&tls_credentials->tls_key, tls_key) != 0)) {
+ tls_credentials_free(tls_credentials);
+ return kr_error(ENOMEM);
+ }
+
+ if ((err = gnutls_certificate_set_x509_key_file(tls_credentials->credentials,
+ tls_cert, tls_key, GNUTLS_X509_FMT_PEM)) != GNUTLS_E_SUCCESS) {
+ tls_credentials_free(tls_credentials);
+ kr_log_error("[tls] gnutls_certificate_set_x509_key_file(%s,%s) failed: %d (%s)\n",
+ tls_cert, tls_key, err, gnutls_strerror_name(err));
+ return kr_error(EINVAL);
+ }
+ /* record the expiration date: */
+ tls_credentials->valid_until = _get_end_entity_expiration(tls_credentials->credentials);
+
+ /* Exchange the x509 credentials */
+ struct tls_credentials *old_credentials = net->tls_credentials;
+
+ /* Start using the new x509_credentials */
+ net->tls_credentials = tls_credentials;
+ tls_credentials_log_pins(net->tls_credentials);
+
+ if (old_credentials) {
+ err = tls_credentials_release(old_credentials);
+ if (err != kr_error(EBUSY)) {
+ return err;
+ }
+ }
+
+ return kr_ok();
+}
+
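+/* Note on the reference counting used by the reserve/release/free trio below,
+ * as implemented in this file: `count` holds the number of extra borrowers
+ * beyond the owning struct network, so a freshly calloc()ed object starts at 0.
+ * tls_credentials_reserve() adds a borrower; tls_credentials_release() drops
+ * one and frees the object only once the count goes negative, otherwise it
+ * returns kr_error(EBUSY) to tell the caller the credentials are still in use. */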
+struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials) {
+ if (!tls_credentials) {
+ return NULL;
+ }
+ tls_credentials->count++;
+ return tls_credentials;
+}
+
+int tls_credentials_release(struct tls_credentials *tls_credentials) {
+ if (!tls_credentials) {
+ return kr_error(EINVAL);
+ }
+ if (--tls_credentials->count < 0) {
+ tls_credentials_free(tls_credentials);
+ } else {
+ return kr_error(EBUSY);
+ }
+ return kr_ok();
+}
+
+void tls_credentials_free(struct tls_credentials *tls_credentials) {
+ if (!tls_credentials) {
+ return;
+ }
+
+ if (tls_credentials->credentials) {
+ gnutls_certificate_free_credentials(tls_credentials->credentials);
+ }
+ if (tls_credentials->tls_cert) {
+ free(tls_credentials->tls_cert);
+ }
+ if (tls_credentials->tls_key) {
+ free(tls_credentials->tls_key);
+ }
+ if (tls_credentials->ephemeral_servicename) {
+ free(tls_credentials->ephemeral_servicename);
+ }
+ free(tls_credentials);
+}
+
+void tls_client_param_unref(tls_client_param_t *entry)
+{
+ if (!entry) return;
+ assert(entry->refs); /* Well, we'd only leak memory. */
+ --(entry->refs);
+ if (entry->refs) return;
+
+ DEBUG_MSG("freeing TLS parameters %p\n", (void *)entry);
+
+ for (int i = 0; i < entry->ca_files.len; ++i) {
+ free_const(entry->ca_files.at[i]);
+ }
+ array_clear(entry->ca_files);
+
+ free_const(entry->hostname);
+
+ for (int i = 0; i < entry->pins.len; ++i) {
+ free_const(entry->pins.at[i]);
+ }
+ array_clear(entry->pins);
+
+ if (entry->credentials) {
+ gnutls_certificate_free_credentials(entry->credentials);
+ }
+
+ if (entry->session_data.data) {
+ gnutls_free(entry->session_data.data);
+ }
+
+ free(entry);
+}
+static int param_free(void **param, void *null)
+{
+ assert(param && *param);
+ tls_client_param_unref(*param);
+ return 0;
+}
+void tls_client_params_free(tls_client_params_t *params)
+{
+ if (!params) return;
+ trie_apply(params, param_free, NULL);
+ trie_free(params);
+}
+
+tls_client_param_t * tls_client_param_new()
+{
+ tls_client_param_t *e = calloc(1, sizeof(*e));
+ if (!e) {
+ assert(!ENOMEM);
+ return NULL;
+ }
+ /* Note: those array_t don't need further initialization. */
+ e->refs = 1;
+ int ret = gnutls_certificate_allocate_credentials(&e->credentials);
+ if (ret != GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls_client] error: gnutls_certificate_allocate_credentials() fails (%s)\n",
+ gnutls_strerror_name(ret));
+ free(e);
+ return NULL;
+ }
+ gnutls_certificate_set_verify_function(e->credentials, client_verify_certificate);
+ return e;
+}
+
+/**
+ * Convert an IP address and port number to binary key.
+ *
+ * \precond buffer \param key must have sufficient size
+ * \param addr[in]
+ * \param len[out] output length
+ * \param key[out] output buffer
+ */
+static bool construct_key(const union inaddr *addr, uint32_t *len, char *key)
+{
+ switch (addr->ip.sa_family) {
+ case AF_INET:
+ memcpy(key, &addr->ip4.sin_port, sizeof(addr->ip4.sin_port));
+ memcpy(key + sizeof(addr->ip4.sin_port), &addr->ip4.sin_addr,
+ sizeof(addr->ip4.sin_addr));
+ *len = sizeof(addr->ip4.sin_port) + sizeof(addr->ip4.sin_addr);
+ return true;
+ case AF_INET6:
+ memcpy(key, &addr->ip6.sin6_port, sizeof(addr->ip6.sin6_port));
+ memcpy(key + sizeof(addr->ip6.sin6_port), &addr->ip6.sin6_addr,
+ sizeof(addr->ip6.sin6_addr));
+ *len = sizeof(addr->ip6.sin6_port) + sizeof(addr->ip6.sin6_addr);
+ return true;
+ default:
+ assert(!EINVAL);
+ return false;
+ }
+}
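+/* Illustration of the key built by construct_key() above: for an IPv4 peer such
+ * as 192.0.2.1:853 it is the 2-byte port followed by the 4-byte address
+ * (6 bytes total); for IPv6 it is 2 + 16 = 18 bytes. Both fields are copied in
+ * network byte order, exactly as stored in the sockaddr. */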
+tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params,
+ const struct sockaddr *addr, bool do_insert)
+{
+ assert(params && addr);
+ /* We accept NULL for empty map; ensure the map exists if needed. */
+ if (!*params) {
+ if (!do_insert) return NULL;
+ *params = trie_create(NULL);
+ if (!*params) {
+ assert(!ENOMEM);
+ return NULL;
+ }
+ }
+ /* Construct the key. */
+ const union inaddr *ia = (const union inaddr *)addr;
+ char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
+ uint32_t len;
+ if (!construct_key(ia, &len, key))
+ return NULL;
+ /* Get the entry. */
+ return (tls_client_param_t **)
+ (do_insert ? trie_get_ins : trie_get_try)(*params, key, len);
+}
+
+int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr)
+{
+ const union inaddr *ia = (const union inaddr *)addr;
+ char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
+ uint32_t len;
+ if (!construct_key(ia, &len, key))
+ return kr_error(EINVAL);
+ trie_val_t param_ptr;
+ int ret = trie_del(params, key, len, &param_ptr);
+ if (ret)
+ return kr_error(ret);
+ tls_client_param_unref(param_ptr);
+ return kr_ok();
+}
+
+/**
+ * Verify that at least one certificate in the certificate chain matches
+ * at least one certificate pin in the non-empty params->pins array.
+ * \returns GNUTLS_E_SUCCESS if pin matches, any other value is an error
+ */
+static int client_verify_pin(const unsigned int cert_list_size,
+ const gnutls_datum_t *cert_list,
+ tls_client_param_t *params)
+{
+ assert(params->pins.len > 0);
+#if TLS_CAN_USE_PINS
+ for (int i = 0; i < cert_list_size; i++) {
+ gnutls_x509_crt_t cert;
+ int ret = gnutls_x509_crt_init(&cert);
+ if (ret != GNUTLS_E_SUCCESS) {
+ return ret;
+ }
+
+ ret = gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER);
+ if (ret != GNUTLS_E_SUCCESS) {
+ gnutls_x509_crt_deinit(cert);
+ return ret;
+ }
+
+ #ifdef DEBUG
+ if (VERBOSE_STATUS) {
+ char pin_base64[TLS_SHA256_BASE64_BUFLEN];
+ /* DEBUG: additionally compute and print the base64 pin.
+ * Not very efficient, but that's OK for DEBUG. */
+ ret = get_oob_key_pin(cert, pin_base64, sizeof(pin_base64), false);
+ if (ret == GNUTLS_E_SUCCESS) {
+ DEBUG_MSG("[tls_client] received pin: %s\n", pin_base64);
+ } else {
+ DEBUG_MSG("[tls_client] failed to convert received pin\n");
+ /* Now we hope that `ret` below can't differ. */
+ }
+ }
+ #endif
+ char cert_pin[TLS_SHA256_RAW_LEN];
+ /* Get raw pin and compare. */
+ ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), true);
+ gnutls_x509_crt_deinit(cert);
+ if (ret != GNUTLS_E_SUCCESS) {
+ return ret;
+ }
+ for (size_t j = 0; j < params->pins.len; ++j) {
+ const uint8_t *pin = params->pins.at[j];
+ if (memcmp(cert_pin, pin, TLS_SHA256_RAW_LEN) != 0)
+ continue; /* mismatch */
+ DEBUG_MSG("[tls_client] matched a configured pin no. %zd\n", j);
+ return GNUTLS_E_SUCCESS;
+ }
+ DEBUG_MSG("[tls_client] none of %zd configured pin(s) matched\n",
+ params->pins.len);
+ }
+
+ kr_log_error("[tls_client] no pin matched: %zu pins * %d certificates\n",
+ params->pins.len, cert_list_size);
+ return GNUTLS_E_CERTIFICATE_ERROR;
+
+#else /* TLS_CAN_USE_PINS */
+ kr_log_error("[tls_client] internal inconsistency: TLS_CAN_USE_PINS\n");
+ assert(false);
+ return GNUTLS_E_CERTIFICATE_ERROR;
+#endif
+}
+
+/**
+ * Verify that \param tls_session contains a valid X.509 certificate chain
+ * with given hostname.
+ *
+ * \returns GNUTLS_E_SUCCESS if certificate chain is valid, any other value is an error
+ */
+static int client_verify_certchain(gnutls_session_t tls_session, const char *hostname)
+{
+ if (!hostname) {
+ kr_log_error("[tls_client] internal config inconsistency: no hostname set\n");
+ assert(false);
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+
+ unsigned int status;
+ int ret = gnutls_certificate_verify_peers3(tls_session, hostname, &status);
+ if ((ret == GNUTLS_E_SUCCESS) && (status == 0)) {
+ return GNUTLS_E_SUCCESS;
+ }
+
+ if (ret == GNUTLS_E_SUCCESS) {
+ gnutls_datum_t msg;
+ ret = gnutls_certificate_verification_status_print(
+ status, gnutls_certificate_type_get(tls_session), &msg, 0);
+ if (ret == GNUTLS_E_SUCCESS) {
+ kr_log_error("[tls_client] failed to verify peer certificate: "
+ "%s\n", msg.data);
+ gnutls_free(msg.data);
+ } else {
+ kr_log_error("[tls_client] failed to verify peer certificate: "
+ "unable to print reason: %s (%s)\n",
+ gnutls_strerror(ret), gnutls_strerror_name(ret));
+ } /* gnutls_certificate_verification_status_print end */
+ } else {
+ kr_log_error("[tls_client] failed to verify peer certificate: "
+ "gnutls_certificate_verify_peers3 error: %s (%s)\n",
+ gnutls_strerror(ret), gnutls_strerror_name(ret));
+ } /* gnutls_certificate_verify_peers3 end */
+ return GNUTLS_E_CERTIFICATE_ERROR;
+}
+
+/**
+ * Verify that actual TLS security parameters of \param tls_session
+ * match requirements provided by user in tls_session->params.
+ * \returns GNUTLS_E_SUCCESS if requirements were met, any other value is an error
+ */
+static int client_verify_certificate(gnutls_session_t tls_session)
+{
+ struct tls_client_ctx_t *ctx = gnutls_session_get_ptr(tls_session);
+ assert(ctx->params != NULL);
+
+ if (ctx->params->insecure) {
+ return GNUTLS_E_SUCCESS;
+ }
+
+ gnutls_certificate_type_t cert_type = gnutls_certificate_type_get(tls_session);
+ if (cert_type != GNUTLS_CRT_X509) {
+ kr_log_error("[tls_client] invalid certificate type %i has been received\n",
+ cert_type);
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+ unsigned int cert_list_size = 0;
+ const gnutls_datum_t *cert_list =
+ gnutls_certificate_get_peers(tls_session, &cert_list_size);
+ if (cert_list == NULL || cert_list_size == 0) {
+ kr_log_error("[tls_client] empty certificate list\n");
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+
+ if (ctx->params->pins.len > 0)
+ /* check hash of the certificate but ignore everything else */
+ return client_verify_pin(cert_list_size, cert_list, ctx->params);
+ else
+ return client_verify_certchain(ctx->c.tls_session, ctx->params->hostname);
+}
+
+struct tls_client_ctx_t *tls_client_ctx_new(tls_client_param_t *entry,
+ struct worker_ctx *worker)
+{
+ struct tls_client_ctx_t *ctx = calloc(1, sizeof (struct tls_client_ctx_t));
+ if (!ctx) {
+ return NULL;
+ }
+ unsigned int flags = GNUTLS_CLIENT | GNUTLS_NONBLOCK
+#ifdef GNUTLS_ENABLE_FALSE_START
+ | GNUTLS_ENABLE_FALSE_START
+#endif
+ ;
+ int ret = gnutls_init(&ctx->c.tls_session, flags);
+ if (ret != GNUTLS_E_SUCCESS) {
+ tls_client_ctx_free(ctx);
+ return NULL;
+ }
+
+ ret = kres_gnutls_set_priority(ctx->c.tls_session);
+ if (ret != GNUTLS_E_SUCCESS) {
+ tls_client_ctx_free(ctx);
+ return NULL;
+ }
+
+ /* Must take a reference on parameters as the credentials are owned by it
+ * and must not be freed while the session is active. */
+ ++(entry->refs);
+ ctx->params = entry;
+
+ ret = gnutls_credentials_set(ctx->c.tls_session, GNUTLS_CRD_CERTIFICATE,
+ entry->credentials);
+ if (ret == GNUTLS_E_SUCCESS && entry->hostname) {
+ ret = gnutls_server_name_set(ctx->c.tls_session, GNUTLS_NAME_DNS,
+ entry->hostname, strlen(entry->hostname));
+ kr_log_verbose("[tls_client] set hostname, ret = %d\n", ret);
+ } else if (!entry->hostname) {
+ kr_log_verbose("[tls_client] no hostname\n");
+ }
+ if (ret != GNUTLS_E_SUCCESS) {
+ tls_client_ctx_free(ctx);
+ return NULL;
+ }
+
+ ctx->c.worker = worker;
+ ctx->c.client_side = true;
+
+ gnutls_transport_set_pull_function(ctx->c.tls_session, kres_gnutls_pull);
+ gnutls_transport_set_vec_push_function(ctx->c.tls_session, kres_gnutls_vec_push);
+ gnutls_transport_set_ptr(ctx->c.tls_session, ctx);
+ return ctx;
+}
+
+void tls_client_ctx_free(struct tls_client_ctx_t *ctx)
+{
+ if (ctx == NULL) {
+ return;
+ }
+
+ if (ctx->c.tls_session != NULL) {
+ gnutls_deinit(ctx->c.tls_session);
+ ctx->c.tls_session = NULL;
+ }
+
+ /* Must decrease the refcount for referenced parameters */
+ tls_client_param_unref(ctx->params);
+
+ free (ctx);
+}
+
+int tls_pull_timeout_func(gnutls_transport_ptr_t h, unsigned int ms)
+{
+ struct tls_common_ctx *t = (struct tls_common_ctx *)h;
+ assert(t != NULL);
+ ssize_t avail = t->nread - t->consumed;
+ DEBUG_MSG("[%s] timeout check: available: %zu\n",
+ t->client_side ? "tls_client" : "tls", avail);
+ if (avail <= 0) {
+ errno = EAGAIN;
+ return -1;
+ }
+ return avail;
+}
+
+int tls_client_connect_start(struct tls_client_ctx_t *client_ctx,
+ struct session *session,
+ tls_handshake_cb handshake_cb)
+{
+ if (session == NULL || client_ctx == NULL) {
+ return kr_error(EINVAL);
+ }
+
+ assert(session_flags(session)->outgoing && session_get_handle(session)->type == UV_TCP);
+
+ struct tls_common_ctx *ctx = &client_ctx->c;
+
+ gnutls_session_set_ptr(ctx->tls_session, client_ctx);
+ gnutls_handshake_set_timeout(ctx->tls_session, ctx->worker->engine->net.tcp.tls_handshake_timeout);
+ gnutls_transport_set_pull_timeout_function(ctx->tls_session, tls_pull_timeout_func);
+ session_tls_set_client_ctx(session, client_ctx);
+ ctx->handshake_cb = handshake_cb;
+ ctx->handshake_state = TLS_HS_IN_PROGRESS;
+ ctx->session = session;
+
+ tls_client_param_t *tls_params = client_ctx->params;
+ if (tls_params->session_data.data != NULL) {
+ gnutls_session_set_data(ctx->tls_session, tls_params->session_data.data,
+ tls_params->session_data.size);
+ }
+
+ /* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
+ while (ctx->handshake_state <= TLS_HS_IN_PROGRESS) {
+ int ret = tls_handshake(ctx, handshake_cb);
+ if (ret != kr_ok()) {
+ return ret;
+ }
+ }
+ return kr_ok();
+}
+
+tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx)
+{
+ return ctx->handshake_state;
+}
+
+int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state)
+{
+ if (state >= TLS_HS_LAST) {
+ return kr_error(EINVAL);
+ }
+ ctx->handshake_state = state;
+ return kr_ok();
+}
+
+int tls_client_ctx_set_session(struct tls_client_ctx_t *ctx, struct session *session)
+{
+ if (!ctx) {
+ return kr_error(EINVAL);
+ }
+ ctx->c.session = session;
+ return kr_ok();
+}
+
+#undef DEBUG_MSG
--- /dev/null
+/* Copyright (C) 2016 American Civil Liberties Union (ACLU)
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#pragma once
+
+#include <uv.h>
+#include <gnutls/gnutls.h>
+#include <libknot/packet/pkt.h>
+#include "lib/defines.h"
+#include "lib/generic/array.h"
+#include "lib/generic/trie.h"
+#include "lib/utils.h"
+
+#define MAX_TLS_PADDING KR_EDNS_PAYLOAD
+#define TLS_MAX_UNCORK_RETRIES 100
+
+/* RFC 5246, section 7.3 - Handshake protocol overview
+ * https://tools.ietf.org/html/rfc5246#page-33
+ * Message flow for a full handshake (only mandatory messages):
+ *      ClientHello        -------->
+ *                                        ServerHello
+ *                         <--------      ServerHelloDone
+ *      ClientKeyExchange
+ *      Finished           -------->
+ *                         <--------      Finished
+ *
+ * See also https://blog.cloudflare.com/keyless-ssl-the-nitty-gritty-technical-details/
+ * So a full handshake takes 2 RTT.
+ * As we use session tickets, there are additional messages; add one RTT more.
+ */
+#define TLS_MAX_HANDSHAKE_TIME (KR_CONN_RTT_MAX * 3)
+
+/** Transport session (opaque). */
+struct session;
+
+struct tls_ctx_t;
+struct tls_client_ctx_t;
+struct tls_credentials {
+ int count;
+ char *tls_cert;
+ char *tls_key;
+ gnutls_certificate_credentials_t credentials;
+ time_t valid_until;
+ char *ephemeral_servicename;
+};
+
+
+#define TLS_SHA256_RAW_LEN 32 /* gnutls_hash_get_len(GNUTLS_DIG_SHA256) */
+/** Required buffer length for pin_sha256, including the zero terminator. */
+#define TLS_SHA256_BASE64_BUFLEN (((TLS_SHA256_RAW_LEN * 8 + 4) / 6) + 3 + 1)
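+/* Sanity check of the formula above: 32 raw bytes encode to ceil(32/3) * 4 = 44
+ * base64 characters, while the expression yields ((256 + 4) / 6) + 3 + 1 = 47,
+ * leaving a couple of bytes of slack on top of the 44 characters plus the
+ * terminating zero. */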
+
+#if GNUTLS_VERSION_NUMBER >= 0x030400
+ #define TLS_CAN_USE_PINS 1
+#else
+ #define TLS_CAN_USE_PINS 0
+#endif
+
+
+/** TLS authentication parameters for a single address-port pair. */
+typedef struct {
+ uint32_t refs; /**< Reference count; consider TLS sessions in progress. */
+ bool insecure; /**< Use no authentication. */
+ const char *hostname; /**< Server name for SNI and certificate check, lowercased. */
+ array_t(const char *) ca_files; /**< Paths to certificate files; not really used. */
+ array_t(const uint8_t *) pins; /**< Certificate pins as raw unterminated strings.*/
+ gnutls_certificate_credentials_t credentials; /**< CA creds. in gnutls format. */
+ gnutls_datum_t session_data; /**< Session-resumption data gets stored here. */
+} tls_client_param_t;
+/** Holds configuration for TLS authentication for all potential servers.
+ * Special case: NULL pointer also means empty. */
+typedef trie_t tls_client_params_t;
+
+/** Get a pointer-to-pointer to TLS auth params.
+ * If it didn't exist, it returns NULL (if !do_insert) or pointer to NULL. */
+tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params,
+ const struct sockaddr *addr, bool do_insert);
+
+/** Get a pointer to TLS auth params or NULL. */
+static inline tls_client_param_t *
+ tls_client_param_get(tls_client_params_t *params, const struct sockaddr *addr)
+{
+ tls_client_param_t **pe = tls_client_param_getptr(&params, addr, false);
+ return pe ? *pe : NULL;
+}
+
+/** Allocate and initialize the structure (with ->ref = 1). */
+tls_client_param_t * tls_client_param_new();
+/** Reference-counted free(); any inside data is freed alongside. */
+void tls_client_param_unref(tls_client_param_t *entry);
+
+int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr);
+/** Free TLS authentication parameters. */
+void tls_client_params_free(tls_client_params_t *params);
+
+
+struct worker_ctx;
+struct qr_task;
+struct network;
+struct engine;
+
+typedef enum tls_client_hs_state {
+ TLS_HS_NOT_STARTED = 0,
+ TLS_HS_IN_PROGRESS,
+ TLS_HS_DONE,
+ TLS_HS_CLOSING,
+ TLS_HS_LAST
+} tls_hs_state_t;
+
+typedef int (*tls_handshake_cb) (struct session *session, int status);
+
+
+struct tls_common_ctx {
+ bool client_side;
+ gnutls_session_t tls_session;
+ tls_hs_state_t handshake_state;
+ struct session *session;
+ /* for reading from the network */
+ const uint8_t *buf;
+ ssize_t nread;
+ ssize_t consumed;
+ uint8_t recv_buf[16384];
+ tls_handshake_cb handshake_cb;
+ struct worker_ctx *worker;
+ size_t write_queue_size;
+};
+
+struct tls_ctx_t {
+ /*
+ * Since pointer to tls_ctx_t needs to be casted
+ * to tls_ctx_common in some functions,
+ * this field must be always at first position
+ */
+ struct tls_common_ctx c;
+ struct tls_credentials *credentials;
+};
+
+struct tls_client_ctx_t {
+ /*
+ * Since pointer to tls_client_ctx_t needs to be casted
+ * to tls_ctx_common in some functions,
+ * this field must be always at first position
+ */
+ struct tls_common_ctx c;
+ tls_client_param_t *params; /**< It's reference-counted. */
+};
+
+/*! Create an empty TLS context in query context */
+struct tls_ctx_t* tls_new(struct worker_ctx *worker);
+
+/*! Close a TLS context (call gnutls_bye()) */
+void tls_close(struct tls_common_ctx *ctx);
+
+/*! Release a TLS context */
+void tls_free(struct tls_ctx_t* tls);
+
+/*! Push new data to TLS context for sending */
+int tls_write(uv_write_t *req, uv_handle_t* handle, knot_pkt_t * pkt, uv_write_cb cb);
+
+/*! Unwrap incoming data from a TLS stream and pass them to TCP session.
+ * @return the number of newly-completed requests (>=0) or an error code
+ */
+ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nread);
+
+/*! Set TLS certificate and key from files. */
+int tls_certificate_set(struct network *net, const char *tls_cert, const char *tls_key);
+
+/*! Borrow TLS credentials for context. */
+struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials);
+
+/*! Release TLS credentials for context (decrements refcount or frees). */
+int tls_credentials_release(struct tls_credentials *tls_credentials);
+
+/*! Free TLS credentials, must not be called if it holds positive refcount. */
+void tls_credentials_free(struct tls_credentials *tls_credentials);
+
+/*! Log DNS-over-TLS OOB key-pin form of current credentials:
+ * https://tools.ietf.org/html/rfc7858#appendix-A */
+void tls_credentials_log_pins(struct tls_credentials *tls_credentials);
+
+/*! Generate new ephemeral TLS credentials. */
+struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine);
+
+/*! Get TLS handshake state. */
+tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx);
+
+/*! Set TLS handshake state. */
+int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state);
+
+
+/*! Allocate new client TLS context */
+struct tls_client_ctx_t *tls_client_ctx_new(tls_client_param_t *entry,
+ struct worker_ctx *worker);
+
+/*! Free client TLS context */
+void tls_client_ctx_free(struct tls_client_ctx_t *ctx);
+
+int tls_client_connect_start(struct tls_client_ctx_t *client_ctx,
+ struct session *session,
+ tls_handshake_cb handshake_cb);
+
+int tls_client_ctx_set_session(struct tls_client_ctx_t *ctx, struct session *session);
+
+
+/* Session tickets, server side. Implementation in ./tls_session_ticket-srv.c */
+
+/*! Opaque struct used by tls_session_ticket_* functions. */
+struct tls_session_ticket_ctx;
+
+/*! Suggested maximum reasonable secret length. */
+#define TLS_SESSION_TICKET_SECRET_MAX_LEN 1024
+
+/*! Create a session ticket context and initialize it (secret gets copied inside).
+ *
+ * Passing zero-length secret implies using a random key, i.e. not synchronized
+ * between multiple instances.
+ *
+ * Beware that knowledge of the secret (if nonempty) breaks forward secrecy,
+ * so you should rotate the secret regularly and securely erase all past secrets.
+ * With TLS < 1.3 it's probably too risky to set nonempty secret.
+ */
+struct tls_session_ticket_ctx * tls_session_ticket_ctx_create(
+ uv_loop_t *loop, const char *secret, size_t secret_len);
+
+/*! Try to enable session tickets for a server session. */
+void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session);
+
+/*! Free all resources of the session ticket context. NULL is accepted as well. */
+void tls_session_ticket_ctx_destroy(struct tls_session_ticket_ctx *ctx);
+
--- /dev/null
+/*
+ * Copyright (C) 2016 American Civil Liberties Union (ACLU)
+ * Copyright (C) 2016-2017 CZ.NIC, z.s.p.o.
+ *
+ * Initial Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
+ *
+ * This program is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option)
+ * any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <sys/file.h>
+#include <unistd.h>
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
+#include <gnutls/crypto.h>
+
+#include "worker.h"
+#include "tls.h"
+
+#define EPHEMERAL_PRIVKEY_FILENAME "ephemeral_key.pem"
+#define INVALID_HOSTNAME "dns-over-tls.invalid"
+#define EPHEMERAL_CERT_EXPIRATION_SECONDS (60*60*24*90)
+
+/* This is an attempt to grab an exclusive, advisory, non-blocking
+ * lock based on a filename. At the moment it's POSIX-only, but it
+ * should be abstract enough of an interface to make an implementation
+ * for non-posix systems if anyone cares. */
+typedef int lock_t;
+static bool _lock_is_invalid(lock_t lock)
+{
+ return lock == -1;
+}
+/* a blocking lock on a given filename */
+static lock_t _lock_filename(const char *fname)
+{
+ lock_t lockfd = open(fname, O_RDONLY|O_CREAT, 0400);
+ if (lockfd == -1)
+ return lockfd;
+ /* this should be a non-blocking lock */
+ if (flock(lockfd, LOCK_EX | LOCK_NB) != 0) {
+ close(lockfd);
+ return -1;
+ }
+ return lockfd; /* for cleanup later */
+}
+static void _lock_unlock(lock_t *lock, const char *fname)
+{
+ if (lock && !_lock_is_invalid(*lock)) {
+ flock(*lock, LOCK_UN);
+ close(*lock);
+ *lock = -1;
+ unlink(fname); /* ignore errors */
+ }
+}
+
+static gnutls_x509_privkey_t get_ephemeral_privkey ()
+{
+ gnutls_x509_privkey_t privkey = NULL;
+ int err;
+ gnutls_datum_t data = { .data = NULL, .size = 0 };
+ lock_t lock;
+ int datafd = -1;
+
+ /* Take a lock to ensure that two daemons started concurrently
+ * with a shared cache don't both create the same privkey: */
+ lock = _lock_filename(EPHEMERAL_PRIVKEY_FILENAME ".lock");
+ if (_lock_is_invalid(lock)) {
+ kr_log_error("[tls] unable to lock lockfile " EPHEMERAL_PRIVKEY_FILENAME ".lock\n");
+ goto done;
+ }
+
+ if ((err = gnutls_x509_privkey_init (&privkey)) < 0) {
+ kr_log_error("[tls] gnutls_x509_privkey_init() failed: %d (%s)\n",
+ err, gnutls_strerror_name(err));
+ goto done;
+ }
+
+ /* Read from the cache file (we assume that we've chdir'ed
+ * already, so we're just looking for the file in the
+ * cachedir). */
+ datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_RDONLY);
+ if (datafd != -1) {
+ struct stat stat;
+ ssize_t bytes_read;
+ if (fstat(datafd, &stat)) {
+ kr_log_error("[tls] unable to stat ephemeral private key " EPHEMERAL_PRIVKEY_FILENAME "\n");
+ goto bad_data;
+ }
+ data.data = gnutls_malloc(stat.st_size);
+ if (data.data == NULL) {
+ kr_log_error("[tls] unable to allocate memory for reading ephemeral private key\n");
+ goto bad_data;
+ }
+ data.size = stat.st_size;
+ bytes_read = read(datafd, data.data, stat.st_size);
+ if (bytes_read != stat.st_size) {
+ kr_log_error("[tls] unable to read ephemeral private key\n");
+ goto bad_data;
+ }
+ if ((err = gnutls_x509_privkey_import (privkey, &data, GNUTLS_X509_FMT_PEM)) < 0) {
+ kr_log_error("[tls] gnutls_x509_privkey_import() failed: %d (%s)\n",
+ err, gnutls_strerror_name(err));
+ /* goto bad_data; */
+ bad_data:
+ close(datafd);
+ datafd = -1;
+ }
+ if (data.data != NULL) {
+ gnutls_free(data.data);
+ data.data = NULL;
+ }
+ }
+ if (datafd == -1) {
+ /* if loading failed, then generate ... */
+#if GNUTLS_VERSION_NUMBER >= 0x030500
+ if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_ECDSA, GNUTLS_CURVE_TO_BITS(GNUTLS_ECC_CURVE_SECP256R1), 0)) < 0) {
+#else
+ if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_RSA, gnutls_sec_param_to_pk_bits(GNUTLS_PK_RSA, GNUTLS_SEC_PARAM_MEDIUM), 0)) < 0) {
+#endif
+ kr_log_error("[tls] gnutls_x509_privkey_init() failed: %d (%s)\n",
+ err, gnutls_strerror_name(err));
+ gnutls_x509_privkey_deinit(privkey);
+ goto done;
+ }
+ /* ... and save */
+ kr_log_info("[tls] Stashing ephemeral private key in " EPHEMERAL_PRIVKEY_FILENAME "\n");
+ if ((err = gnutls_x509_privkey_export2(privkey, GNUTLS_X509_FMT_PEM, &data)) < 0) {
+ kr_log_error("[tls] gnutls_x509_privkey_export2() failed: %d (%s), not storing\n",
+ err, gnutls_strerror_name(err));
+ } else {
+ datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_WRONLY|O_CREAT, 0600);
+ if (datafd == -1) {
+ kr_log_error("[tls] failed to open " EPHEMERAL_PRIVKEY_FILENAME " to store the ephemeral key\n");
+ } else {
+ ssize_t bytes_written;
+ bytes_written = write(datafd, data.data, data.size);
+ if (bytes_written != data.size)
+ kr_log_error("[tls] failed to write %d octets to "
+ EPHEMERAL_PRIVKEY_FILENAME
+ " (%zd written)\n",
+ data.size, bytes_written);
+ }
+ }
+ }
+ done:
+ _lock_unlock(&lock, EPHEMERAL_PRIVKEY_FILENAME ".lock");
+ if (datafd != -1) {
+ close(datafd);
+ }
+ if (data.data != NULL) {
+ gnutls_free(data.data);
+ }
+ return privkey;
+}
+
+static gnutls_x509_crt_t get_ephemeral_cert(gnutls_x509_privkey_t privkey, const char *servicename, time_t invalid_before, time_t valid_until)
+{
+ gnutls_x509_crt_t cert = NULL;
+ int err;
+ /* need a random buffer of bytes */
+ uint8_t serial[16];
+ gnutls_rnd(GNUTLS_RND_NONCE, serial, sizeof(serial));
+ /* clear the left-most bit to avoid signedness confusion: */
+ serial[0] &= 0x7f;
+ size_t namelen = strlen(servicename);
+
+#define gtx(fn, ...) \
+ if ((err = fn ( __VA_ARGS__ )) != GNUTLS_E_SUCCESS) { \
+ kr_log_error("[tls] " #fn "() failed: %d (%s)\n", \
+ err, gnutls_strerror_name(err)); \
+ goto bad; }
+
+ gtx(gnutls_x509_crt_init, &cert);
+ gtx(gnutls_x509_crt_set_activation_time, cert, invalid_before);
+ gtx(gnutls_x509_crt_set_ca_status, cert, 0);
+ gtx(gnutls_x509_crt_set_expiration_time, cert, valid_until);
+ gtx(gnutls_x509_crt_set_key, cert, privkey);
+ gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_CLIENT, 0);
+ gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_SERVER, 0);
+ gtx(gnutls_x509_crt_set_key_usage, cert, GNUTLS_KEY_DIGITAL_SIGNATURE);
+ gtx(gnutls_x509_crt_set_serial, cert, serial, sizeof(serial));
+ gtx(gnutls_x509_crt_set_subject_alt_name, cert, GNUTLS_SAN_DNSNAME, servicename, namelen, GNUTLS_FSAN_SET);
+ gtx(gnutls_x509_crt_set_dn_by_oid,cert, GNUTLS_OID_X520_COMMON_NAME, 0, servicename, namelen);
+ gtx(gnutls_x509_crt_set_version, cert, 3);
+ gtx(gnutls_x509_crt_sign2,cert, cert, privkey, GNUTLS_DIG_SHA256, 0); /* self-sign, since it doesn't look like we can just stub-sign */
+#undef gtx
+
+ return cert;
+bad:
+ gnutls_x509_crt_deinit(cert);
+ return NULL;
+}
+
+struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine)
+{
+ struct tls_credentials *creds = NULL;
+ gnutls_x509_privkey_t privkey = NULL;
+ gnutls_x509_crt_t cert = NULL;
+ int err;
+ time_t now = time(NULL);
+
+ creds = calloc(1, sizeof(*creds));
+ if (!creds) {
+ kr_log_error("[tls] failed to allocate memory for ephemeral credentials\n");
+ return NULL;
+ }
+ if ((err = gnutls_certificate_allocate_credentials(&(creds->credentials))) < 0) {
+ kr_log_error("[tls] failed to allocate memory for ephemeral credentials\n");
+ goto failure;
+ }
+
+ creds->valid_until = now + EPHEMERAL_CERT_EXPIRATION_SECONDS;
+ creds->ephemeral_servicename = strdup(engine_get_hostname(engine));
+ if (creds->ephemeral_servicename == NULL) {
+ kr_log_error("[tls] could not get server's hostname, using '" INVALID_HOSTNAME "' instead\n");
+ if ((creds->ephemeral_servicename = strdup(INVALID_HOSTNAME)) == NULL) {
+ kr_log_error("[tls] failed to allocate memory for ephemeral credentials\n");
+ goto failure;
+ }
+ }
+ if ((privkey = get_ephemeral_privkey()) == NULL) {
+ goto failure;
+ }
+ if ((cert = get_ephemeral_cert(privkey, creds->ephemeral_servicename, now - 60*15, creds->valid_until)) == NULL) {
+ goto failure;
+ }
+ if ((err = gnutls_certificate_set_x509_key(creds->credentials, &cert, 1, privkey)) < 0) {
+ kr_log_error("[tls] failed to set up ephemeral credentials\n");
+ goto failure;
+ }
+ gnutls_x509_privkey_deinit(privkey);
+ gnutls_x509_crt_deinit(cert);
+ return creds;
+ failure:
+ gnutls_x509_privkey_deinit(privkey);
+ gnutls_x509_crt_deinit(cert);
+ tls_credentials_free(creds);
+ return NULL;
+}
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <assert.h>
+#include <inttypes.h>
+#include <stdbool.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/time.h>
+
+#include <gnutls/gnutls.h>
+#include <gnutls/crypto.h>
+#include <uv.h>
+
+#include "lib/utils.h"
+
+/* Style: "local/static" identifiers are usually named tst_* */
+
+/** The number of seconds between synchronized rotation of TLS session ticket key. */
+#define TST_KEY_LIFETIME 4096
+
+/** Value from gnutls:lib/ext/session_ticket.c
+ * Beware: changing this needs to change the hashing implementation. */
+#define SESSION_KEY_SIZE 64
+
+/** Compile-time support for setting the secret. */
+/* This is not secure with TLS <= 1.2 but TLS 1.3 and secure configuration
+ * is not available in GnuTLS yet. See https://gitlab.com/gnutls/gnutls/issues/477
+#ifndef TLS_SESSION_RESUMPTION_SYNC
+ #define TLS_SESSION_RESUMPTION_SYNC (GNUTLS_VERSION_NUMBER >= 0x030603)
+#endif
+*/
+
+#if GNUTLS_VERSION_NUMBER < 0x030400
+ /* It's of little use anyway. We may get the secret through lua,
+ * which creates a copy outside of our control. */
+ #define gnutls_memset memset
+#endif
+
+#ifdef GNUTLS_DIG_SHA3_512
+ #define TST_HASH GNUTLS_DIG_SHA3_512
+#else
+ #define TST_HASH abort()
+#endif
+
+/** Fields are internal to tst_key_* functions. */
+typedef struct tls_session_ticket_ctx {
+ uv_timer_t timer; /**< timer for rotation of the key */
+ unsigned char key[SESSION_KEY_SIZE]; /**< the key itself */
+ bool has_secret; /**< false -> key is random for each epoch */
+ uint16_t hash_len; /**< length of `hash_data` */
+ char hash_data[]; /**< data to hash to obtain `key`;
+ * it's `time_t epoch` and then the secret string */
+} tst_ctx_t;
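+/* How the key is (re)computed below: hash_data begins with the current epoch
+ * (a raw time_t) optionally followed by the configured secret. With a secret
+ * and TLS_SESSION_RESUMPTION_SYNC enabled, `key` is the SHA3-512 digest of that
+ * buffer, so instances sharing the secret derive identical keys for a given
+ * epoch; without a secret, a fresh random key is generated for each epoch. */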
+
+/** Check invariants, based on gnutls version. */
+static bool tst_key_invariants(void)
+{
+ static int result = 0; /*< cache for multiple invocations */
+ if (result) return result > 0;
+ bool ok = true;
+ #if TLS_SESSION_RESUMPTION_SYNC
+ /* SHA3-512 output size may never change, but let's check it anyway :-) */
+ ok = ok && gnutls_hash_get_len(TST_HASH) == SESSION_KEY_SIZE;
+ #endif
+ /* The ticket key size might change in a different gnutls version. */
+ gnutls_datum_t key = { 0, 0 };
+ ok = ok && gnutls_session_ticket_key_generate(&key) == 0
+ && key.size == SESSION_KEY_SIZE;
+ free(key.data);
+ result = ok ? 1 : -1;
+ return ok;
+}
+
+/** Create the internal structures and copy the secret. Beware: secret must be kept secure. */
+static tst_ctx_t * tst_key_create(const char *secret, size_t secret_len, uv_loop_t *loop)
+{
+ const size_t hash_len = sizeof(time_t) + secret_len;
+ if (secret_len &&
+ (!secret || hash_len > UINT16_MAX || hash_len < secret_len)) {
+ assert(!EINVAL);
+ return NULL;
+ /* reasonable secret_len is best enforced in config API */
+ }
+ if (!tst_key_invariants()) {
+ assert(!EFAULT);
+ return NULL;
+ }
+ #if !TLS_SESSION_RESUMPTION_SYNC
+ if (secret_len) {
+ kr_log_error("[tls] session ticket: secrets were not enabled at compile-time (your GnuTLS version is not supported)\n");
+ return NULL; /* ENOTSUP */
+ }
+ #endif
+
+ tst_ctx_t *ctx = malloc(sizeof(*ctx) + hash_len); /* can be slightly longer */
+ if (!ctx) return NULL;
+ ctx->has_secret = secret_len > 0;
+ ctx->hash_len = hash_len;
+ if (secret_len) {
+ memcpy(ctx->hash_data + sizeof(time_t), secret, secret_len);
+ }
+
+ if (uv_timer_init(loop, &ctx->timer) != 0) {
+ free(ctx);
+ return NULL;
+ }
+ ctx->timer.data = ctx;
+ return ctx;
+}
+
+/** Random variant of secret rotation: generate into key_tmp and copy. */
+static int tst_key_get_random(tst_ctx_t *ctx)
+{
+ gnutls_datum_t key_tmp = { NULL, 0 };
+ int err = gnutls_session_ticket_key_generate(&key_tmp);
+ if (err) return kr_error(err);
+ if (key_tmp.size != SESSION_KEY_SIZE) {
+ assert(!EFAULT);
+ return kr_error(EFAULT);
+ }
+ memcpy(ctx->key, key_tmp.data, SESSION_KEY_SIZE);
+ gnutls_memset(key_tmp.data, 0, SESSION_KEY_SIZE);
+ free(key_tmp.data);
+ return kr_ok();
+}
+
+/** Recompute the session ticket key, if epoch has changed or forced. */
+static int tst_key_update(tst_ctx_t *ctx, time_t epoch, bool force_update)
+{
+ if (!ctx || ctx->hash_len < sizeof(epoch)) {
+ assert(!EINVAL);
+ return kr_error(EINVAL);
+ }
+ /* documented limitation: time_t and endianess must match
+ * on instances sharing a secret */
+ if (!force_update && memcmp(ctx->hash_data, &epoch, sizeof(epoch)) == 0) {
+ return kr_ok(); /* we are up to date */
+ }
+ memcpy(ctx->hash_data, &epoch, sizeof(epoch));
+
+ if (!ctx->has_secret) {
+ return tst_key_get_random(ctx);
+ }
+ /* Otherwise, deterministic variant of secret rotation, if supported. */
+ #if !TLS_SESSION_RESUMPTION_SYNC
+ assert(false);
+ return kr_error(ENOTSUP);
+ #else
+ int err = gnutls_hash_fast(TST_HASH, ctx->hash_data,
+ ctx->hash_len, ctx->key);
+ return err == 0 ? kr_ok() : kr_error(err);
+ #endif
+}
+
+/** Free all resources of the key (securely). */
+static void tst_key_destroy(uv_handle_t *timer)
+{
+ assert(timer);
+ tst_ctx_t *ctx = timer->data;
+ assert(ctx);
+ gnutls_memset(ctx, 0, offsetof(tst_ctx_t, hash_data) + ctx->hash_len);
+ free(ctx);
+}
+
+static void tst_key_check(uv_timer_t *timer, bool force_update);
+static void tst_timer_callback(uv_timer_t *timer)
+{
+ tst_key_check(timer, false);
+}
+
+/** Update the ST key if needed and reschedule itself via the timer. */
+static void tst_key_check(uv_timer_t *timer, bool force_update)
+{
+ tst_ctx_t *stst = (tst_ctx_t *)timer->data;
+ /* Compute the current epoch. */
+ struct timeval now;
+ if (gettimeofday(&now, NULL)) {
+ kr_log_error("[tls] session ticket: gettimeofday failed, %s\n",
+ strerror(errno));
+ return;
+ }
+ uv_update_time(timer->loop); /* to have sync. between real and mono time */
+ const time_t epoch = now.tv_sec / TST_KEY_LIFETIME;
+ /* Update the key; new sessions will fetch it from the location.
+ * Old ones hopefully can't get broken by that; documentation
+ * for gnutls_session_ticket_enable_server() doesn't say. */
+ int err = tst_key_update(stst, epoch, force_update);
+ if (err) {
+ assert(err != kr_error(EINVAL));
+ kr_log_error("[tls] session ticket: failed rotation, err = %d\n", err);
+ }
+ /* Reschedule. */
+ const time_t tv_sec_next = (epoch + 1) * TST_KEY_LIFETIME;
+ const uint64_t ms_until_second = 1000 - (now.tv_usec + 501) / 1000;
+ const uint64_t remain_ms = (tv_sec_next - now.tv_sec - 1) * (uint64_t)1000
+ + ms_until_second + 1;
+ /* ^ +1 because we don't want to wake up half a millisecond before the epoch! */
+ assert(remain_ms < (TST_KEY_LIFETIME + 1 /*rounding tolerance*/) * 1000);
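+ /* Worked example with illustrative numbers: with TST_KEY_LIFETIME = 4096 and
+ * `now` 10.3 s past an epoch boundary, ms_until_second = 1000 - 300 = 700 and
+ * remain_ms = (4096 - 10 - 1) * 1000 + 700 + 1 = 4085701 ms, i.e. the timer
+ * fires just after the next epoch begins. */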
+ kr_log_verbose("[tls] session ticket: epoch %"PRIu64
+ ", scheduling rotation check in %"PRIu64" ms\n",
+ (uint64_t)epoch, remain_ms);
+ err = uv_timer_start(timer, &tst_timer_callback, remain_ms, 0);
+ if (err) {
+ assert(false);
+ kr_log_error("[tls] session ticket: failed to schedule, err = %d\n", err);
+ }
+}
+
+/* Implementation for prototypes from ./tls.h */
+
+void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session)
+{
+ assert(ctx && session);
+ const gnutls_datum_t gd = {
+ .size = SESSION_KEY_SIZE,
+ .data = ctx->key,
+ };
+ int err = gnutls_session_ticket_enable_server(session, &gd);
+ if (err) {
+ kr_log_error("[tls] failed to enable session tickets: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ /* but continue without tickets */
+ }
+}
+
+tst_ctx_t * tls_session_ticket_ctx_create(uv_loop_t *loop, const char *secret,
+ size_t secret_len)
+{
+ assert(loop && (!secret_len || secret));
+ #if GNUTLS_VERSION_NUMBER < 0x030500
+ /* We would need different SESSION_KEY_SIZE; avoid assert. */
+ return NULL;
+ #endif
+ tst_ctx_t *ctx = tst_key_create(secret, secret_len, loop);
+ if (ctx) {
+ tst_key_check(&ctx->timer, true);
+ }
+ return ctx;
+}
+
+void tls_session_ticket_ctx_destroy(tst_ctx_t *ctx)
+{
+ if (ctx == NULL) {
+ return;
+ }
+ uv_close((uv_handle_t *)&ctx->timer, &tst_key_destroy);
+}
+
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "kresconfig.h"
+#include "udp_queue.h"
+
+#include "session.h"
+#include "worker.h"
+#include "lib/generic/array.h"
+#include "lib/utils.h"
+
+struct qr_task;
+
+#include <assert.h>
+#include <sys/socket.h>
+
+
+#if !ENABLE_SENDMMSG
+int udp_queue_init_global(uv_loop_t *loop)
+{
+ return 0;
+}
+/* Appease the linker in case this unused call isn't optimized out. */
+void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task)
+{
+ abort();
+}
+#else
+
+/* LATER: it might be useful to have this configurable during runtime,
+ * but the structures below would have to change a little (broken up). */
+#define UDP_QUEUE_LEN 64
+
+/** A queue of up to UDP_QUEUE_LEN messages, meant for the same socket. */
+typedef struct {
+ int len; /**< The number of messages in the queue: 0..UDP_QUEUE_LEN */
+ struct mmsghdr msgvec[UDP_QUEUE_LEN]; /**< Parameter for sendmmsg() */
+ struct {
+ struct qr_task *task; /**< Links for completion callbacks. */
+ struct iovec msg_iov[1]; /**< storage for .msgvec[i].msg_iov */
+ } items[UDP_QUEUE_LEN];
+} udp_queue_t;
+
+static udp_queue_t * udp_queue_create()
+{
+ udp_queue_t *q = calloc(1, sizeof(*q));
+ for (int i = 0; i < UDP_QUEUE_LEN; ++i) {
+ struct msghdr *mhi = &q->msgvec[i].msg_hdr;
+ /* These shall remain always the same. */
+ mhi->msg_iov = q->items[i].msg_iov;
+ mhi->msg_iovlen = 1;
+ /* msg_name and msg_namelen will be per-call,
+ * and the rest is OK to remain zeroed all the time. */
+ }
+ return q;
+}
+
+/** Global state for udp_queue_*. Note: we never free the pointed-to memory. */
+struct {
+ /** Singleton map: fd -> udp_queue_t, as a simple array of pointers. */
+ udp_queue_t **udp_queues;
+ int udp_queues_len;
+
+ /** List of FD numbers that might have a non-empty queue. */
+ array_t(int) waiting_fds;
+
+ uv_check_t check_handle;
+} static state = {0};
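+/* How the batching fits together (as wired below): udp_queue_push() appends a
+ * message to the per-fd queue and remembers the fd in waiting_fds;
+ * udp_queue_check(), registered as a uv_check handle, then flushes every
+ * waiting queue with a single sendmmsg() call once per event-loop iteration.
+ * A queue that fills up to UDP_QUEUE_LEN is flushed immediately from push. */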
+
+/** Empty the given queue. The queue is assumed to exist (but may be empty). */
+static void udp_queue_send(int fd)
+{
+ udp_queue_t *const q = state.udp_queues[fd];
+ if (!q->len) return;
+ int sent_len = sendmmsg(fd, q->msgvec, q->len, 0);
+ /* ATM we don't really do anything about failures. */
+ int err = sent_len < 0 ? errno : EAGAIN /* unknown error, really */;
+ if (unlikely(sent_len != q->len)) {
+ if (err != EWOULDBLOCK) {
+ kr_log_error("ERROR: udp sendmmsg() sent %d / %d; %s\n",
+ sent_len, q->len, strerror(err));
+ } else {
+ const uint64_t stamp_now = kr_now();
+ static uint64_t stamp_last = 0;
+ if (stamp_now > stamp_last + 60*1000) {
+ kr_log_info("WARNING: dropped UDP reply packet(s) due to network overload (reported at most once per minute)\n");
+ stamp_last = stamp_now;
+ }
+ }
+ }
+ for (int i = 0; i < q->len; ++i) {
+ qr_task_on_send(q->items[i].task, NULL, i < sent_len ? 0 : err);
+ worker_task_unref(q->items[i].task);
+ }
+ q->len = 0;
+}
+
+/** Periodical callback to send all queued packets. */
+static void udp_queue_check(uv_check_t *handle)
+{
+ for (int i = 0; i < state.waiting_fds.len; ++i) {
+ udp_queue_send(state.waiting_fds.at[i]);
+ }
+ state.waiting_fds.len = 0;
+}
+
+int udp_queue_init_global(uv_loop_t *loop)
+{
+ int ret = uv_check_init(loop, &state.check_handle);
+ if (!ret) ret = uv_check_start(&state.check_handle, udp_queue_check);
+ return ret;
+}
+
+void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task)
+{
+ if (fd < 0) {
+ kr_log_error("ERROR: called udp_queue_push(fd = %d, ...)\n", fd);
+ abort();
+ }
+ worker_task_ref(task);
+ /* Get a valid queue for this fd, growing the fd-indexed table if needed. */
+ if (fd >= state.udp_queues_len) {
+ const int new_len = fd + 1;
+ state.udp_queues = realloc(state.udp_queues,
+ sizeof(state.udp_queues[0]) * new_len);
+ if (!state.udp_queues) abort();
+ memset(state.udp_queues + state.udp_queues_len, 0,
+ sizeof(state.udp_queues[0]) * (new_len - state.udp_queues_len));
+ state.udp_queues_len = new_len;
+ }
+ if (unlikely(state.udp_queues[fd] == NULL))
+ state.udp_queues[fd] = udp_queue_create();
+ udp_queue_t *const q = state.udp_queues[fd];
+
+ /* Append to the queue */
+ struct sockaddr *sa = (struct sockaddr *)/*const-cast*/req->qsource.addr;
+ q->msgvec[q->len].msg_hdr.msg_name = sa;
+ q->msgvec[q->len].msg_hdr.msg_namelen = kr_sockaddr_len(sa);
+ q->items[q->len].task = task;
+ q->items[q->len].msg_iov[0] = (struct iovec){
+ .iov_base = req->answer->wire,
+ .iov_len = req->answer->size,
+ };
+ if (q->len == 0)
+ array_push(state.waiting_fds, fd);
+ ++(q->len);
+
+ if (q->len >= UDP_QUEUE_LEN) {
+ assert(q->len == UDP_QUEUE_LEN);
+ udp_queue_send(fd);
+ /* We don't need to search state.waiting_fds;
+ * anyway, it's more efficient to let the hook do that. */
+ }
+}
+
+#endif
+
--- /dev/null
+/* Copyright (C) 2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <uv.h>
+struct kr_request;
+struct qr_task;
+
+/** Initialize the global state for udp_queue. */
+int udp_queue_init_global(uv_loop_t *loop);
+
+/** Send req->answer via UDP, possibly not immediately. */
+void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task);
+
--- /dev/null
+#include <assert.h>
+#include <gnutls/gnutls.h>
+#include <gnutls/crypto.h>
+
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+
+#include "kresconfig.h"
+#include "lib/utils.h"
+#include "modules/sysrepo/common/sysrepo.h"
+
+#include "worker.h"
+#include "watcher.h"
+#include "sr_subscriptions.h"
+
+/* Rotation interval for the TLS session ticket secret: 12 hours, in milliseconds. */
+#define TST_SECRET_CYCLE (12*60*60*1000)
+
+
+/* default configuration */
+struct server_config default_config = {
+ .auto_start = false,
+ .auto_cache_gc = true,
+ .kresd_instances = 1
+};
+
+static void tst_secret_check(uv_timer_t *timer, bool timer_update);
+static void tst_timer_callback(uv_timer_t *timer)
+{
+ tst_secret_check(timer, true);
+}
+
+static tst_secret_ctx_t * tst_secret_create(uv_loop_t *loop)
+{
+ struct tst_secret_ctx *ctx = malloc(sizeof(*ctx));
+ if (!ctx) return NULL;
+
+ if (uv_timer_init(loop, &ctx->timer) != 0) {
+ free(ctx);
+ return NULL;
+ }
+ ctx->timer.data = ctx;
+ return ctx;
+}
+
+static void tst_secret_check(uv_timer_t *timer, bool timer_update)
+{
+ int ret = 0;
+ uv_update_time(timer->loop);
+
+ if(timer_update) {
+ uint8_t *base64;
+ gnutls_datum_t key_tmp = { NULL, 0 };
+ ret = gnutls_session_ticket_key_generate(&key_tmp);
+ if (ret) {
+ kr_log_error("[watcher] failed to generate TLS session ticket secret: %s\n", gnutls_strerror(ret));
+ return;
+ }
+
+ int32_t len = base64_encode_alloc((uint8_t *)key_tmp.data, key_tmp.size, &base64);
+ if (len < 0) {
+ kr_log_error("[watcher] failed to encode TLS session ticket secret in base64\n");
+ free(key_tmp.data);
+ return;
+ }
+
+ base64[len-1] = '\0';
+ char *secret = (char*) base64;
+
+ kr_log_info("[watcher] generated new secret for tls session ticket\n");
+ ret = set_tst_secret(secret);
+
+ free(key_tmp.data);
+ free(secret);
+ }
+ uv_timer_start(timer, &tst_timer_callback, TST_SECRET_CYCLE, 0);
+}
+
+static tst_secret_ctx_t * tst_secret_ctx_create(uv_loop_t *loop, bool timer_update)
+{
+ assert(loop);
+ tst_secret_ctx_t *ctx = tst_secret_create(loop);
+ if (ctx) {
+ tst_secret_check(&ctx->timer, timer_update);
+ kr_log_info("[watcher] new context for TLS session ticket secret created\n");
+ }
+ return ctx;
+}
+
+static void tst_secret_destroy(uv_handle_t *timer)
+{
+ assert(timer);
+ struct tst_secret_ctx *ctx = timer->data;
+ assert(ctx);
+ free(ctx);
+}
+
+
+static void tst_secret_timer_destroy(tst_secret_ctx_t *ctx)
+{
+ if (ctx == NULL) {
+ return;
+ }
+ uv_close((uv_handle_t *)&ctx->timer, &tst_secret_destroy);
+}
+
+int tst_secret_timer_init(uv_loop_t *loop)
+{
+ tst_secret_ctx_t *tst_ctx = the_worker->engine->watcher.tst_secret;
+
+ tst_secret_timer_destroy(tst_ctx);
+
+ tst_ctx = tst_secret_ctx_create(loop, false);
+ the_worker->engine->watcher.tst_secret = tst_ctx;
+ if (the_worker->engine->watcher.tst_secret == NULL) {
+ kr_log_error("[watcher] failed to create tls session ticket secret context");
+ return 1;
+ }
+ return 0;
+}
+
+void watcher_init(struct watcher_context *watcher, uv_loop_t *loop)
+{
+ assert(watcher != NULL);
+ if (watcher != NULL) {
+ watcher->loop = loop;
+ watcher->config = default_config;
+
+ /* Init sysrepo context */
+ watcher->sysrepo = sysrepo_ctx_init();
+
+ /* Init timer for tls session ticket secret generation */
+ watcher->tst_secret = tst_secret_ctx_create(loop, true);
+ }
+}
+
+void watcher_deinit(struct watcher_context *watcher)
+{
+ assert(watcher);
+ if (watcher != NULL) {
+ sysrepo_ctx_deinit(watcher->sysrepo);
+ tst_secret_timer_destroy(watcher->tst_secret);
+ }
+}
\ No newline at end of file
--- /dev/null
+#pragma once
+
+#include <stdbool.h>
+#include <uv.h>
+
+#include <sysrepo.h>
+#include <libyang/libyang.h>
+#include <systemd/sd-bus.h>
+
+#include "modules/sysrepo/common/sysrepo.h"
+
+
+struct server_config {
+ bool auto_start;
+ bool auto_cache_gc;
+ uint8_t kresd_instances;
+};
+
+typedef struct tst_secret_ctx {
+ uv_timer_t timer;
+} tst_secret_ctx_t;
+
+typedef struct sdbus_ctx {
+
+} sdbus_ctx_t;
+
+struct watcher_context {
+ uv_loop_t *loop;
+ sysrepo_uv_ctx_t *sysrepo;
+ sdbus_ctx_t *sdbus;
+ tst_secret_ctx_t *tst_secret;
+ struct server_config config;
+};
+
+void watcher_init(struct watcher_context *watcher, uv_loop_t *loop);
+
+void watcher_deinit(struct watcher_context *watcher);
\ No newline at end of file
--- /dev/null
+/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "kresconfig.h"
+#include "worker.h"
+
+#include <uv.h>
+#include <lua.h>
+#include <lauxlib.h>
+#include <libknot/packet/pkt.h>
+#include <libknot/descriptor.h>
+#include <contrib/ucw/lib.h>
+#include <contrib/ucw/mempool.h>
+#include <contrib/wire.h>
+#if defined(__GLIBC__) && defined(_GNU_SOURCE)
+#include <malloc.h>
+#endif
+#include <assert.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <gnutls/gnutls.h>
+
+#include "bindings/api.h"
+#include "engine.h"
+#include "io.h"
+#include "session.h"
+#include "tls.h"
+#include "udp_queue.h"
+#include "zimport.h"
+#include "lib/layer.h"
+#include "lib/utils.h"
+
+
+/* Magic defaults for the worker. */
+#ifndef MP_FREELIST_SIZE
+# ifdef __clang_analyzer__
+# define MP_FREELIST_SIZE 0
+# else
+# define MP_FREELIST_SIZE 64 /**< Maximum length of the worker mempool freelist */
+# endif
+#endif
+#ifndef QUERY_RATE_THRESHOLD
+#define QUERY_RATE_THRESHOLD (2 * MP_FREELIST_SIZE) /**< Nr of parallel queries considered as high rate */
+#endif
+#ifndef MAX_PIPELINED
+#define MAX_PIPELINED 100
+#endif
+
+#define VERBOSE_MSG(qry, ...) QRVERBOSE(qry, "wrkr", __VA_ARGS__)
+
+/** Client request state. */
+struct request_ctx
+{
+ struct kr_request req;
+
+ struct {
+ /** Requestor's address; separate because of UDP session "sharing". */
+ union inaddr addr;
+ /** NULL if the request didn't come over network. */
+ struct session *session;
+ } source;
+
+ struct worker_ctx *worker;
+ struct qr_task *task;
+};
+
+/** Query resolution task. */
+struct qr_task
+{
+ struct request_ctx *ctx;
+ knot_pkt_t *pktbuf;
+ qr_tasklist_t waiting;
+ struct session *pending[MAX_PENDING];
+ uint16_t pending_count;
+ uint16_t addrlist_count;
+ uint16_t addrlist_turn;
+ uint16_t timeouts;
+ uint16_t iter_count;
+ struct sockaddr *addrlist;
+ uint32_t refs;
+ bool finished : 1;
+ bool leading : 1;
+ uint64_t creation_time;
+};
+
+
+/* Convenience macros */
+#define qr_task_ref(task) \
+ do { ++(task)->refs; } while(0)
+#define qr_task_unref(task) \
+ do { \
+ if (task) \
+ assert((task)->refs > 0); \
+ if ((task) && --(task)->refs == 0) \
+ qr_task_free((task)); \
+ } while (0)
+
+/** @internal get key for tcp session
+ * @note kr_straddr() return pointer to static string
+ */
+#define tcpsess_key(addr) kr_straddr(addr)
+
+/* Forward decls */
+static void qr_task_free(struct qr_task *task);
+static int qr_task_step(struct qr_task *task,
+ const struct sockaddr *packet_source,
+ knot_pkt_t *packet);
+static int qr_task_send(struct qr_task *task, struct session *session,
+ const struct sockaddr *addr, knot_pkt_t *pkt);
+static int qr_task_finalize(struct qr_task *task, int state);
+static void qr_task_complete(struct qr_task *task);
+static struct session* worker_find_tcp_connected(struct worker_ctx *worker,
+ const struct sockaddr *addr);
+static int worker_add_tcp_waiting(struct worker_ctx *worker,
+ const struct sockaddr *addr,
+ struct session *session);
+static struct session* worker_find_tcp_waiting(struct worker_ctx *worker,
+ const struct sockaddr *addr);
+static void on_tcp_connect_timeout(uv_timer_t *timer);
+
+struct worker_ctx the_worker_value; /**< Static allocation is suitable for the singleton. */
+struct worker_ctx *the_worker = NULL;
+
+/*! @internal Create a UDP/TCP handle for an outgoing AF_INET* connection.
+ * socktype is SOCK_* */
+static uv_handle_t *ioreq_spawn(struct worker_ctx *worker,
+ int socktype, sa_family_t family, bool has_tls)
+{
+ bool precond = (socktype == SOCK_DGRAM || socktype == SOCK_STREAM)
+ && (family == AF_INET || family == AF_INET6);
+ if (!precond) {
+ assert(false);
+ kr_log_verbose("[work] ioreq_spawn: pre-condition failed\n");
+ return NULL;
+ }
+
+ /* Create connection for iterative query */
+ uv_handle_t *handle = malloc(socktype == SOCK_DGRAM
+ ? sizeof(uv_udp_t) : sizeof(uv_tcp_t));
+ if (!handle) {
+ return NULL;
+ }
+ int ret = io_create(worker->loop, handle, socktype, family, has_tls);
+ if (ret) {
+ if (ret == UV_EMFILE) {
+ worker->too_many_open = true;
+ worker->rconcurrent_highwatermark = worker->stats.rconcurrent;
+ }
+ free(handle);
+ return NULL;
+ }
+
+ /* Bind to outgoing address, according to IP v4/v6. */
+ union inaddr *addr;
+ if (family == AF_INET) {
+ addr = (union inaddr *)&worker->out_addr4;
+ } else {
+ addr = (union inaddr *)&worker->out_addr6;
+ }
+ if (addr->ip.sa_family != AF_UNSPEC) {
+ assert(addr->ip.sa_family == family);
+ if (socktype == SOCK_DGRAM) {
+ uv_udp_t *udp = (uv_udp_t *)handle;
+ ret = uv_udp_bind(udp, &addr->ip, 0);
+ } else if (socktype == SOCK_STREAM){
+ uv_tcp_t *tcp = (uv_tcp_t *)handle;
+ ret = uv_tcp_bind(tcp, &addr->ip, 0);
+ }
+ }
+
+ if (ret != 0) {
+ io_deinit(handle);
+ free(handle);
+ return NULL;
+ }
+
+ /* Set current handle as a subrequest type. */
+ struct session *session = handle->data;
+ session_flags(session)->outgoing = true;
+ /* Connect or issue query datagram */
+ return handle;
+}
+
+static void ioreq_kill_pending(struct qr_task *task)
+{
+ for (uint16_t i = 0; i < task->pending_count; ++i) {
+ session_kill_ioreq(task->pending[i], task);
+ }
+ task->pending_count = 0;
+}
+
+/** @cond This memory layout is internal to mempool.c, use only for debugging. */
+#if defined(__SANITIZE_ADDRESS__)
+struct mempool_chunk {
+ struct mempool_chunk *next;
+ size_t size;
+};
+static void mp_poison(struct mempool *mp, bool poison)
+{
+ if (!poison) { /* @note mempool is part of the first chunk, unpoison it first */
+ kr_asan_unpoison(mp, sizeof(*mp));
+ }
+ struct mempool_chunk *chunk = mp->state.last[0];
+ void *chunk_off = (uint8_t *)chunk - chunk->size;
+ if (poison) {
+ kr_asan_poison(chunk_off, chunk->size);
+ } else {
+ kr_asan_unpoison(chunk_off, chunk->size);
+ }
+}
+#else
+#define mp_poison(mp, enable)
+#endif
+/** @endcond */
+
+/** Get a mempool. (Recycle if possible.) */
+static inline struct mempool *pool_borrow(struct worker_ctx *worker)
+{
+ struct mempool *mp = NULL;
+ if (worker->pool_mp.len > 0) {
+ mp = array_tail(worker->pool_mp);
+ array_pop(worker->pool_mp);
+ mp_poison(mp, 0);
+ } else { /* No mempool on the freelist, create new one */
+ mp = mp_new (4 * CPU_PAGE_SIZE);
+ }
+ return mp;
+}
+
+/** Return a mempool. (Cache them up to some count.) */
+static inline void pool_release(struct worker_ctx *worker, struct mempool *mp)
+{
+ if (worker->pool_mp.len < MP_FREELIST_SIZE) {
+ mp_flush(mp);
+ array_push(worker->pool_mp, mp);
+ mp_poison(mp, 1);
+ } else {
+ mp_delete(mp);
+ }
+}
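+/* Illustrative sketch (not part of the original code; do_work() is a
+ * hypothetical helper): the typical borrow/use/release cycle around a request.
+ *
+ *	struct mempool *mp = pool_borrow(worker);
+ *	knot_mm_t pool = { .ctx = mp, .alloc = (knot_mm_alloc_t) mp_alloc };
+ *	do_work(&pool);              // allocate request-scoped data from the pool
+ *	pool_release(worker, mp);    // flush and cache it, or delete it if the freelist is full
+ */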
+
+/** Create a key for an outgoing subrequest: qname, qclass, qtype.
+ * @param dst Destination buffer for the key; MUST be SUBREQ_KEY_LEN or larger.
+ * @return key length if successful or an error
+ */
+static const size_t SUBREQ_KEY_LEN = KR_RRKEY_LEN;
+static int subreq_key(char *dst, knot_pkt_t *pkt)
+{
+ assert(pkt);
+ return kr_rrkey(dst, knot_pkt_qclass(pkt), knot_pkt_qname(pkt),
+ knot_pkt_qtype(pkt), knot_pkt_qtype(pkt));
+}
+
+/** Create and initialize a request_ctx (on a fresh mempool).
+ *
+ * session and peer point to the source of the request; they are NULL
+ * when the request didn't come from the network.
+ */
+static struct request_ctx *request_create(struct worker_ctx *worker,
+ struct session *session,
+ const struct sockaddr *peer,
+ uint32_t uid)
+{
+ knot_mm_t pool = {
+ .ctx = pool_borrow(worker),
+ .alloc = (knot_mm_alloc_t) mp_alloc
+ };
+
+ /* Create request context */
+ struct request_ctx *ctx = mm_alloc(&pool, sizeof(*ctx));
+ if (!ctx) {
+ pool_release(worker, pool.ctx);
+ return NULL;
+ }
+
+ memset(ctx, 0, sizeof(*ctx));
+
+ /* TODO Relocate pool to struct request */
+ ctx->worker = worker;
+ if (session) {
+ assert(session_flags(session)->outgoing == false);
+ }
+ ctx->source.session = session;
+
+ struct kr_request *req = &ctx->req;
+ req->pool = pool;
+ req->vars_ref = LUA_NOREF;
+ req->uid = uid;
+ if (session) {
+ /* We assume the session will be alive during the whole life of the request. */
+ req->qsource.dst_addr = session_get_sockname(session);
+ req->qsource.flags.tcp = session_get_handle(session)->type == UV_TCP;
+ req->qsource.flags.tls = session_flags(session)->has_tls;
+ /* We need to store a copy of peer address. */
+ memcpy(&ctx->source.addr.ip, peer, kr_sockaddr_len(peer));
+ req->qsource.addr = &ctx->source.addr.ip;
+ }
+
+ worker->stats.rconcurrent += 1;
+
+ return ctx;
+}
+
+/** More initialization, related to the particular incoming query/packet. */
+static int request_start(struct request_ctx *ctx, knot_pkt_t *query)
+{
+ assert(query && ctx);
+ size_t answer_max = KNOT_WIRE_MIN_PKTSIZE;
+ struct kr_request *req = &ctx->req;
+
+ /* source.session can be empty if request was generated by kresd itself */
+ struct session *s = ctx->source.session;
+ if (!s || session_get_handle(s)->type == UV_TCP) {
+ answer_max = KNOT_WIRE_MAX_PKTSIZE;
+ } else if (knot_pkt_has_edns(query)) { /* EDNS */
+ answer_max = MAX(knot_edns_get_payload(query->opt_rr),
+ KNOT_WIRE_MIN_PKTSIZE);
+ }
+ req->qsource.size = query->size;
+ if (knot_pkt_has_tsig(query)) {
+ req->qsource.size += query->tsig_wire.len;
+ }
+
+ knot_pkt_t *answer = knot_pkt_new(NULL, answer_max, &req->pool);
+ if (!answer) { /* Failed to allocate answer */
+ return kr_error(ENOMEM);
+ }
+
+ knot_pkt_t *pkt = knot_pkt_new(NULL, req->qsource.size, &req->pool);
+ if (!pkt) {
+ return kr_error(ENOMEM);
+ }
+
+ int ret = knot_pkt_copy(pkt, query);
+ if (ret != KNOT_EOK && ret != KNOT_ETRAIL) {
+ return kr_error(ENOMEM);
+ }
+ req->qsource.packet = pkt;
+
+ /* Start resolution */
+ struct worker_ctx *worker = ctx->worker;
+ struct engine *engine = worker->engine;
+ kr_resolve_begin(req, &engine->resolver, answer);
+ worker->stats.queries += 1;
+	/* Throttle outbound queries only under high pressure. */
+ if (worker->stats.concurrent < QUERY_RATE_THRESHOLD) {
+ req->options.NO_THROTTLE = true;
+ }
+ return kr_ok();
+}
+
+static void request_free(struct request_ctx *ctx)
+{
+ struct worker_ctx *worker = ctx->worker;
+ /* Dereference any Lua vars table if exists */
+ if (ctx->req.vars_ref != LUA_NOREF) {
+ lua_State *L = worker->engine->L;
+ /* Get worker variables table */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, worker->vars_table_ref);
+ /* Get next free element (position 0) and store it under current reference (forming a list) */
+ lua_rawgeti(L, -1, 0);
+ lua_rawseti(L, -2, ctx->req.vars_ref);
+ /* Set current reference as the next free element */
+ lua_pushinteger(L, ctx->req.vars_ref);
+ lua_rawseti(L, -2, 0);
+ lua_pop(L, 1);
+ ctx->req.vars_ref = LUA_NOREF;
+ }
+ /* Return mempool to ring or free it if it's full */
+ pool_release(worker, ctx->req.pool.ctx);
+ /* @note The 'task' is invalidated from now on. */
+ worker->stats.rconcurrent -= 1;
+}
+
+static struct qr_task *qr_task_create(struct request_ctx *ctx)
+{
+ /* How much can client handle? */
+ struct engine *engine = ctx->worker->engine;
+ size_t pktbuf_max = KR_EDNS_PAYLOAD;
+ if (engine->resolver.opt_rr) {
+ pktbuf_max = MAX(knot_edns_get_payload(engine->resolver.opt_rr),
+ pktbuf_max);
+ }
+
+ /* Create resolution task */
+ struct qr_task *task = mm_alloc(&ctx->req.pool, sizeof(*task));
+ if (!task) {
+ return NULL;
+ }
+	memset(task, 0, sizeof(*task)); /* avoid accidentally uninitialized fields */
+
+ /* Create packet buffers for answer and subrequests */
+ knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &ctx->req.pool);
+ if (!pktbuf) {
+ mm_free(&ctx->req.pool, task);
+ return NULL;
+ }
+ pktbuf->size = 0;
+
+ task->ctx = ctx;
+ task->pktbuf = pktbuf;
+ array_init(task->waiting);
+ task->refs = 0;
+ assert(ctx->task == NULL);
+ ctx->task = task;
+ /* Make the primary reference to task. */
+ qr_task_ref(task);
+ task->creation_time = kr_now();
+ ctx->worker->stats.concurrent += 1;
+ return task;
+}
+
+/* This is called when the task refcount drops to zero; free its memory. */
+static void qr_task_free(struct qr_task *task)
+{
+ struct request_ctx *ctx = task->ctx;
+
+ assert(ctx);
+
+ struct worker_ctx *worker = ctx->worker;
+
+ if (ctx->task == NULL) {
+ request_free(ctx);
+ }
+
+ /* Update stats */
+ worker->stats.concurrent -= 1;
+}
+
+/** Register a new qr_task within the session. */
+static int qr_task_register(struct qr_task *task, struct session *session)
+{
+ assert(!session_flags(session)->outgoing && session_get_handle(session)->type == UV_TCP);
+
+ session_tasklist_add(session, task);
+
+ struct request_ctx *ctx = task->ctx;
+ assert(ctx && (ctx->source.session == NULL || ctx->source.session == session));
+ ctx->source.session = session;
+	/* Soft-limit on parallel queries: there is no "slow down" RCODE
+	 * that we could use to signal the client, but we can stop reading,
+	 * and in effect shrink the TCP window size. To get more precise throttling,
+	 * we would need to copy the remainder of the unread buffer and reassemble
+	 * it when resuming reading. This is NYI. */
+ if (session_tasklist_get_len(session) >= task->ctx->worker->tcp_pipeline_max &&
+ !session_flags(session)->throttled && !session_flags(session)->closing) {
+ session_stop_read(session);
+ session_flags(session)->throttled = true;
+ }
+
+ return 0;
+}
+
+static void qr_task_complete(struct qr_task *task)
+{
+ struct request_ctx *ctx = task->ctx;
+
+ /* Kill pending I/O requests */
+ ioreq_kill_pending(task);
+ assert(task->waiting.len == 0);
+ assert(task->leading == false);
+
+ struct session *s = ctx->source.session;
+ if (s) {
+ assert(!session_flags(s)->outgoing && session_waitinglist_is_empty(s));
+ ctx->source.session = NULL;
+ session_tasklist_del(s, task);
+ }
+
+ /* Release primary reference to task. */
+ if (ctx->task == task) {
+ ctx->task = NULL;
+ qr_task_unref(task);
+ }
+}
+
+/* This is called when we send subrequest / answer */
+int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status)
+{
+
+ if (task->finished) {
+ assert(task->leading == false);
+ qr_task_complete(task);
+ }
+
+ if (!handle || handle->type != UV_TCP) {
+ return status;
+ }
+
+ struct session* s = handle->data;
+ assert(s);
+ if (status != 0) {
+ session_tasklist_del(s, task);
+ }
+
+ if (session_flags(s)->outgoing || session_flags(s)->closing) {
+ return status;
+ }
+
+ struct worker_ctx *worker = task->ctx->worker;
+ if (session_flags(s)->throttled &&
+ session_tasklist_get_len(s) < worker->tcp_pipeline_max/2) {
+ /* Start reading again if the session is throttled and
+ * the number of outgoing requests is below watermark. */
+ session_start_read(s);
+ session_flags(s)->throttled = false;
+ }
+
+ return status;
+}
+
+static void on_send(uv_udp_send_t *req, int status)
+{
+ struct qr_task *task = req->data;
+ uv_handle_t *h = (uv_handle_t *)req->handle;
+ qr_task_on_send(task, h, status);
+ qr_task_unref(task);
+ free(req);
+}
+
+static void on_write(uv_write_t *req, int status)
+{
+ struct qr_task *task = req->data;
+ uv_handle_t *h = (uv_handle_t *)req->handle;
+ qr_task_on_send(task, h, status);
+ qr_task_unref(task);
+ free(req);
+}
+
+static int qr_task_send(struct qr_task *task, struct session *session,
+ const struct sockaddr *addr, knot_pkt_t *pkt)
+{
+ if (!session) {
+ return qr_task_on_send(task, NULL, kr_error(EIO));
+ }
+
+ int ret = 0;
+ struct request_ctx *ctx = task->ctx;
+
+ uv_handle_t *handle = session_get_handle(session);
+ assert(handle && handle->data == session);
+ const bool is_stream = handle->type == UV_TCP;
+ if (!is_stream && handle->type != UV_UDP) abort();
+
+ if (addr == NULL) {
+ addr = session_get_peer(session);
+ }
+
+ if (pkt == NULL) {
+ pkt = worker_task_get_pktbuf(task);
+ }
+
+ if (session_flags(session)->outgoing && handle->type == UV_TCP) {
+ size_t try_limit = session_tasklist_get_len(session) + 1;
+ uint16_t msg_id = knot_wire_get_id(pkt->wire);
+ size_t try_count = 0;
+ while (session_tasklist_find_msgid(session, msg_id) &&
+ try_count <= try_limit) {
+ ++msg_id;
+ ++try_count;
+ }
+ if (try_count > try_limit) {
+ return kr_error(ENOENT);
+ }
+ worker_task_pkt_set_msgid(task, msg_id);
+ }
+
+ uv_handle_t *ioreq = malloc(is_stream ? sizeof(uv_write_t) : sizeof(uv_udp_send_t));
+ if (!ioreq) {
+ return qr_task_on_send(task, handle, kr_error(ENOMEM));
+ }
+
+ /* Pending ioreq on current task */
+ qr_task_ref(task);
+
+ struct worker_ctx *worker = ctx->worker;
+ /* Send using given protocol */
+ assert(!session_flags(session)->closing);
+ if (session_flags(session)->has_tls) {
+ uv_write_t *write_req = (uv_write_t *)ioreq;
+ write_req->data = task;
+ ret = tls_write(write_req, handle, pkt, &on_write);
+ } else if (handle->type == UV_UDP) {
+ uv_udp_send_t *send_req = (uv_udp_send_t *)ioreq;
+ uv_buf_t buf = { (char *)pkt->wire, pkt->size };
+ send_req->data = task;
+ ret = uv_udp_send(send_req, (uv_udp_t *)handle, &buf, 1, addr, &on_send);
+ } else if (handle->type == UV_TCP) {
+ uv_write_t *write_req = (uv_write_t *)ioreq;
+		/* We need to write the 2-byte message length in network byte order,
+		 * but we don't have a convenient place to store those bytes.
+		 * The problem is that all memory referenced from buf[] MUST retain
+		 * its contents at least until on_write() is called, and there is
+		 * currently no convenient place outside the `pkt` structure.
+		 * So we directly use the *individual* bytes of pkt->size.
+		 * The call to htonl() and the condition will probably be inlinable. */
+ int lsbi, slsbi; /* (second) least significant byte index */
+ if (htonl(1) == 1) { /* big endian */
+ lsbi = sizeof(pkt->size) - 1;
+ slsbi = sizeof(pkt->size) - 2;
+ } else {
+ lsbi = 0;
+ slsbi = 1;
+ }
+ uv_buf_t buf[3] = {
+ { (char *)&pkt->size + slsbi, 1 },
+ { (char *)&pkt->size + lsbi, 1 },
+ { (char *)pkt->wire, pkt->size },
+ };
+ write_req->data = task;
+ ret = uv_write(write_req, (uv_stream_t *)handle, buf, 3, &on_write);
+ } else {
+ assert(false);
+ }
+
+ if (ret == 0) {
+ session_touch(session);
+ if (session_flags(session)->outgoing) {
+ session_tasklist_add(session, task);
+ }
+ if (worker->too_many_open &&
+ worker->stats.rconcurrent <
+ worker->rconcurrent_highwatermark - 10) {
+ worker->too_many_open = false;
+ }
+ } else {
+ free(ioreq);
+ qr_task_unref(task);
+ if (ret == UV_EMFILE) {
+ worker->too_many_open = true;
+ worker->rconcurrent_highwatermark = worker->stats.rconcurrent;
+ ret = kr_error(UV_EMFILE);
+ }
+ }
+
+ /* Update statistics */
+ if (session_flags(session)->outgoing && addr) {
+ if (session_flags(session)->has_tls)
+ worker->stats.tls += 1;
+ else if (handle->type == UV_UDP)
+ worker->stats.udp += 1;
+ else
+ worker->stats.tcp += 1;
+
+ if (addr->sa_family == AF_INET6)
+ worker->stats.ipv6 += 1;
+ else if (addr->sa_family == AF_INET)
+ worker->stats.ipv4 += 1;
+ }
+ return ret;
+}
+
+static struct kr_query *task_get_last_pending_query(struct qr_task *task)
+{
+ if (!task || task->ctx->req.rplan.pending.len == 0) {
+ return NULL;
+ }
+
+ return array_tail(task->ctx->req.rplan.pending);
+}
+
+static int session_tls_hs_cb(struct session *session, int status)
+{
+ assert(session_flags(session)->outgoing);
+ uv_handle_t *handle = session_get_handle(session);
+ uv_loop_t *loop = handle->loop;
+ struct worker_ctx *worker = loop->data;
+ struct sockaddr *peer = session_get_peer(session);
+ int deletion_res = worker_del_tcp_waiting(worker, peer);
+ int ret = kr_ok();
+
+ if (status) {
+ struct qr_task *task = session_waitinglist_get(session);
+ if (task) {
+ struct kr_qflags *options = &task->ctx->req.options;
+ unsigned score = options->FORWARD || options->STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD;
+ kr_nsrep_update_rtt(NULL, peer, score,
+ worker->engine->resolver.cache_rtt,
+ KR_NS_UPDATE_NORESET);
+ }
+#ifndef NDEBUG
+ else {
+		/* The task isn't in the list of tasks
+		 * waiting for a connection to the upstream,
+		 * so this MUST be an unsuccessful rehandshake.
+		 * Check it. */
+ assert(deletion_res != 0);
+ const char *key = tcpsess_key(peer);
+ assert(key);
+ assert(map_contains(&worker->tcp_connected, key) != 0);
+ }
+#endif
+ return ret;
+ }
+
+ /* handshake was completed successfully */
+ struct tls_client_ctx_t *tls_client_ctx = session_tls_get_client_ctx(session);
+ tls_client_param_t *tls_params = tls_client_ctx->params;
+ gnutls_session_t tls_session = tls_client_ctx->c.tls_session;
+ if (gnutls_session_is_resumed(tls_session) != 0) {
+ kr_log_verbose("[tls_client] TLS session has resumed\n");
+ } else {
+ kr_log_verbose("[tls_client] TLS session has not resumed\n");
+ /* session wasn't resumed, delete old session data ... */
+ if (tls_params->session_data.data != NULL) {
+ gnutls_free(tls_params->session_data.data);
+ tls_params->session_data.data = NULL;
+ tls_params->session_data.size = 0;
+ }
+ /* ... and get the new session data */
+ gnutls_datum_t tls_session_data = { NULL, 0 };
+ ret = gnutls_session_get_data2(tls_session, &tls_session_data);
+ if (ret == 0) {
+ tls_params->session_data = tls_session_data;
+ }
+ }
+
+ struct session *s = worker_find_tcp_connected(worker, peer);
+ ret = kr_ok();
+ if (deletion_res == kr_ok()) {
+ /* peer was in the waiting list, add to the connected list. */
+ if (s) {
+ /* Something went wrong,
+			 * the peer is already in the connected list. */
+ ret = kr_error(EINVAL);
+ } else {
+ ret = worker_add_tcp_connected(worker, peer, session);
+ }
+ } else {
+		/* The peer wasn't in the waiting list.
+		 * This can be either
+		 * 1) a successful rehandshake, in which case the peer
+		 *    must already be in the connected list, or
+		 * 2) a successful handshake on a session that was timed out
+		 *    by on_tcp_connect_timeout() after the TCP connection succeeded,
+		 *    in which case the peer isn't in the connected list.
+		 */
+ if (!s || s != session) {
+ ret = kr_error(EINVAL);
+ }
+ }
+ if (ret == kr_ok()) {
+ while (!session_waitinglist_is_empty(session)) {
+ struct qr_task *t = session_waitinglist_get(session);
+ ret = qr_task_send(t, session, NULL, NULL);
+ if (ret != 0) {
+ break;
+ }
+ session_waitinglist_pop(session, true);
+ }
+ } else {
+ ret = kr_error(EINVAL);
+ }
+
+ if (ret != kr_ok()) {
+ /* Something went wrong.
+ * Either addition to the list of connected sessions
+ * or write to upstream failed. */
+ worker_del_tcp_connected(worker, peer);
+ session_waitinglist_finalize(session, KR_STATE_FAIL);
+ assert(session_tasklist_is_empty(session));
+ session_close(session);
+ } else {
+ session_timer_stop(session);
+ session_timer_start(session, tcp_timeout_trigger,
+ MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
+ }
+ return kr_ok();
+}
+
+static int send_waiting(struct session *session)
+{
+ int ret = 0;
+ while (!session_waitinglist_is_empty(session)) {
+ struct qr_task *t = session_waitinglist_get(session);
+ ret = qr_task_send(t, session, NULL, NULL);
+ if (ret != 0) {
+ struct worker_ctx *worker = t->ctx->worker;
+ struct sockaddr *peer = session_get_peer(session);
+ session_waitinglist_finalize(session, KR_STATE_FAIL);
+ session_tasklist_finalize(session, KR_STATE_FAIL);
+ worker_del_tcp_connected(worker, peer);
+ session_close(session);
+ break;
+ }
+ session_waitinglist_pop(session, true);
+ }
+ return ret;
+}
+
+static void on_connect(uv_connect_t *req, int status)
+{
+ struct worker_ctx *worker = the_worker;
+ assert(worker);
+ uv_stream_t *handle = req->handle;
+ struct session *session = handle->data;
+ struct sockaddr *peer = session_get_peer(session);
+ free(req);
+
+ assert(session_flags(session)->outgoing);
+
+ if (session_flags(session)->closing) {
+ worker_del_tcp_waiting(worker, peer);
+ assert(session_is_empty(session));
+ return;
+ }
+
+ /* Check if the connection is in the waiting list.
+	 * If not, this is most likely a timed-out connection
+	 * which was removed from the waiting list by
+	 * the on_tcp_connect_timeout() callback. */
+ struct session *s = worker_find_tcp_waiting(worker, peer);
+ if (!s || s != session) {
+		/* The session isn't on the waiting list;
+		 * it has already timed out. */
+ if (VERBOSE_STATUS) {
+ const char *peer_str = kr_straddr(peer);
+ kr_log_verbose( "[wrkr]=> connected to '%s', but session "
+					"has already timed out, close\n",
+ peer_str ? peer_str : "");
+ }
+ assert(session_tasklist_is_empty(session));
+ session_waitinglist_retry(session, false);
+ session_close(session);
+ return;
+ }
+
+ s = worker_find_tcp_connected(worker, peer);
+ if (s) {
+		/* The session is already in the connected list.
+		 * Something went wrong; it can be due to races when kresd has tried
+		 * to reconnect to the upstream after an unsuccessful attempt. */
+ if (VERBOSE_STATUS) {
+ const char *peer_str = kr_straddr(peer);
+ kr_log_verbose( "[wrkr]=> connected to '%s', but peer "
+ "is already connected, close\n",
+ peer_str ? peer_str : "");
+ }
+ assert(session_tasklist_is_empty(session));
+ session_waitinglist_retry(session, false);
+ session_close(session);
+ return;
+ }
+
+ if (status != 0) {
+ if (VERBOSE_STATUS) {
+ const char *peer_str = kr_straddr(peer);
+ kr_log_verbose( "[wrkr]=> connection to '%s' failed (%s), flagged as 'bad'\n",
+ peer_str ? peer_str : "", uv_strerror(status));
+ }
+ worker_del_tcp_waiting(worker, peer);
+ struct qr_task *task = session_waitinglist_get(session);
+ if (task && status != UV_ETIMEDOUT) {
+ /* Penalize upstream.
+			 * In case of UV_ETIMEDOUT the upstream has already
+			 * been penalized in on_tcp_connect_timeout(). */
+ struct kr_qflags *options = &task->ctx->req.options;
+ unsigned score = options->FORWARD || options->STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD;
+ kr_nsrep_update_rtt(NULL, peer, score,
+ worker->engine->resolver.cache_rtt,
+ KR_NS_UPDATE_NORESET);
+ }
+ assert(session_tasklist_is_empty(session));
+ session_waitinglist_retry(session, false);
+ session_close(session);
+ return;
+ }
+
+ if (!session_flags(session)->has_tls) {
+		/* With TLS the session is still waiting for the handshake;
+		 * otherwise remove it from the waiting list. */
+ if (worker_del_tcp_waiting(worker, peer) != 0) {
+			/* The session isn't in the list of waiting queries;
+			 * something went wrong. */
+ session_waitinglist_finalize(session, KR_STATE_FAIL);
+ assert(session_tasklist_is_empty(session));
+ session_close(session);
+ return;
+ }
+ }
+
+ if (VERBOSE_STATUS) {
+ const char *peer_str = kr_straddr(peer);
+ kr_log_verbose( "[wrkr]=> connected to '%s'\n", peer_str ? peer_str : "");
+ }
+
+ session_flags(session)->connected = true;
+ session_start_read(session);
+
+ int ret = kr_ok();
+ if (session_flags(session)->has_tls) {
+ struct tls_client_ctx_t *tls_ctx = session_tls_get_client_ctx(session);
+ ret = tls_client_connect_start(tls_ctx, session, session_tls_hs_cb);
+ if (ret == kr_error(EAGAIN)) {
+ session_timer_stop(session);
+ session_timer_start(session, tcp_timeout_trigger,
+ MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
+ return;
+ }
+ } else {
+ worker_add_tcp_connected(worker, peer, session);
+ }
+
+ ret = send_waiting(session);
+ if (ret != 0) {
+ return;
+ }
+
+ session_timer_stop(session);
+ session_timer_start(session, tcp_timeout_trigger,
+ MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
+}
+
+static void on_tcp_connect_timeout(uv_timer_t *timer)
+{
+ struct session *session = timer->data;
+
+ uv_timer_stop(timer);
+ struct worker_ctx *worker = the_worker;
+ assert(worker);
+
+ assert (session_tasklist_is_empty(session));
+
+ struct sockaddr *peer = session_get_peer(session);
+ worker_del_tcp_waiting(worker, peer);
+
+ struct qr_task *task = session_waitinglist_get(session);
+ if (!task) {
+ /* Normally shouldn't happen. */
+ const char *peer_str = kr_straddr(peer);
+ VERBOSE_MSG(NULL, "=> connection to '%s' failed (internal timeout), empty waitinglist\n",
+ peer_str ? peer_str : "");
+ return;
+ }
+
+ struct kr_query *qry = task_get_last_pending_query(task);
+ WITH_VERBOSE (qry) {
+ const char *peer_str = kr_straddr(peer);
+ VERBOSE_MSG(qry, "=> connection to '%s' failed (internal timeout)\n",
+ peer_str ? peer_str : "");
+ }
+
+ unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD;
+ kr_nsrep_update_rtt(NULL, peer, score,
+ worker->engine->resolver.cache_rtt,
+ KR_NS_UPDATE_NORESET);
+
+ worker->stats.timeout += session_waitinglist_get_len(session);
+ session_waitinglist_retry(session, true);
+ assert (session_tasklist_is_empty(session));
+	/* uv_cancel() doesn't support uv_connect_t requests,
+	 * so we can't cancel it.
+	 * There is still a possibility that the connection
+	 * for this request succeeds.
+	 * Therefore the connection callback (on_connect()) must check
+	 * whether the connection is in the list of waiting connections.
+	 * If not, it is most likely a timed-out connection even if
+	 * it eventually succeeded. */
+}
+
+/* This is called when the I/O times out. */
+static void on_udp_timeout(uv_timer_t *timer)
+{
+ struct session *session = timer->data;
+ assert(session_get_handle(session)->data == session);
+ assert(session_tasklist_get_len(session) == 1);
+ assert(session_waitinglist_is_empty(session));
+
+ uv_timer_stop(timer);
+
+ /* Penalize all tried nameservers with a timeout. */
+ struct qr_task *task = session_tasklist_get_first(session);
+ struct worker_ctx *worker = task->ctx->worker;
+ if (task->leading && task->pending_count > 0) {
+ struct kr_query *qry = array_tail(task->ctx->req.rplan.pending);
+ struct sockaddr_in6 *addrlist = (struct sockaddr_in6 *)task->addrlist;
+ for (uint16_t i = 0; i < MIN(task->pending_count, task->addrlist_count); ++i) {
+ struct sockaddr *choice = (struct sockaddr *)(&addrlist[i]);
+ WITH_VERBOSE(qry) {
+ char *addr_str = kr_straddr(choice);
+ VERBOSE_MSG(qry, "=> server: '%s' flagged as 'bad'\n", addr_str ? addr_str : "");
+ }
+ unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD;
+ kr_nsrep_update_rtt(&qry->ns, choice, score,
+ worker->engine->resolver.cache_rtt,
+ KR_NS_UPDATE_NORESET);
+ }
+ }
+ task->timeouts += 1;
+ worker->stats.timeout += 1;
+ qr_task_step(task, NULL, NULL);
+}
+
+static uv_handle_t *retransmit(struct qr_task *task)
+{
+ uv_handle_t *ret = NULL;
+ if (task && task->addrlist && task->addrlist_count > 0) {
+ struct sockaddr_in6 *choice = &((struct sockaddr_in6 *)task->addrlist)[task->addrlist_turn];
+ if (!choice) {
+ return ret;
+ }
+ if (task->pending_count >= MAX_PENDING) {
+ return ret;
+ }
+		/* Checkout the outgoing query before sending it */
+ struct request_ctx *ctx = task->ctx;
+ if (kr_resolve_checkout(&ctx->req, NULL, (struct sockaddr *)choice, SOCK_DGRAM, task->pktbuf) != 0) {
+ return ret;
+ }
+ ret = ioreq_spawn(ctx->worker, SOCK_DGRAM, choice->sin6_family, false);
+ if (!ret) {
+ return ret;
+ }
+ struct sockaddr *addr = (struct sockaddr *)choice;
+ struct session *session = ret->data;
+ struct sockaddr *peer = session_get_peer(session);
+ assert (peer->sa_family == AF_UNSPEC && session_flags(session)->outgoing);
+ memcpy(peer, addr, kr_sockaddr_len(addr));
+ if (qr_task_send(task, session, (struct sockaddr *)choice,
+ task->pktbuf) != 0) {
+ session_close(session);
+ ret = NULL;
+ } else {
+ task->pending[task->pending_count] = session;
+ task->pending_count += 1;
+ task->addrlist_turn = (task->addrlist_turn + 1) %
+ task->addrlist_count; /* Round robin */
+ session_start_read(session); /* Start reading answer */
+ }
+ }
+ return ret;
+}
+
+static void on_retransmit(uv_timer_t *req)
+{
+ struct session *session = req->data;
+ assert(session_tasklist_get_len(session) == 1);
+
+ uv_timer_stop(req);
+ struct qr_task *task = session_tasklist_get_first(session);
+ if (retransmit(task) == NULL) {
+ /* Not possible to spawn request, start timeout timer with remaining deadline. */
+ struct kr_qflags *options = &task->ctx->req.options;
+ uint64_t timeout = options->FORWARD || options->STUB ? KR_NS_FWD_TIMEOUT / 2 :
+ KR_CONN_RTT_MAX - task->pending_count * KR_CONN_RETRY;
+ uv_timer_start(req, on_udp_timeout, timeout, 0);
+ } else {
+ uv_timer_start(req, on_retransmit, KR_CONN_RETRY, 0);
+ }
+}
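+/* Note (descriptive comment added for clarity): on_retransmit() re-sends the
+ * query to the next address in task->addrlist every KR_CONN_RETRY ms; once
+ * retransmit() can no longer spawn a request (e.g. MAX_PENDING sockets are
+ * already in flight), it falls back to a single on_udp_timeout() timer that
+ * covers the remaining deadline. */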
+
+static void subreq_finalize(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *pkt)
+{
+ if (!task || task->finished) {
+ return;
+ }
+ /* Close pending timer */
+ ioreq_kill_pending(task);
+ /* Clear from outgoing table. */
+ if (!task->leading)
+ return;
+ char key[SUBREQ_KEY_LEN];
+ const int klen = subreq_key(key, task->pktbuf);
+ if (klen > 0) {
+ void *val_deleted;
+ int ret = trie_del(task->ctx->worker->subreq_out, key, klen, &val_deleted);
+ assert(ret == KNOT_EOK && val_deleted == task); (void)ret;
+ }
+ /* Notify waiting tasks. */
+ struct kr_query *leader_qry = array_tail(task->ctx->req.rplan.pending);
+ for (size_t i = task->waiting.len; i > 0; i--) {
+ struct qr_task *follower = task->waiting.at[i - 1];
+ /* Reuse MSGID and 0x20 secret */
+ if (follower->ctx->req.rplan.pending.len > 0) {
+ struct kr_query *qry = array_tail(follower->ctx->req.rplan.pending);
+ qry->id = leader_qry->id;
+ qry->secret = leader_qry->secret;
+ leader_qry->secret = 0; /* Next will be already decoded */
+ }
+ qr_task_step(follower, packet_source, pkt);
+ qr_task_unref(follower);
+ }
+ task->waiting.len = 0;
+ task->leading = false;
+}
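+/* Note on outgoing-subrequest deduplication (descriptive comment added for
+ * clarity): the first task to send a given (qname, qclass, qtype) registers
+ * itself in worker->subreq_out via subreq_lead() and becomes the leader;
+ * later tasks with the same key attach to it via subreq_enqueue() instead of
+ * sending a duplicate query, and subreq_finalize() above replays the result
+ * to all followers once the leader's answer arrives. */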
+
+static void subreq_lead(struct qr_task *task)
+{
+ assert(task);
+ char key[SUBREQ_KEY_LEN];
+ const int klen = subreq_key(key, task->pktbuf);
+ if (klen < 0)
+ return;
+ struct qr_task **tvp = (struct qr_task **)
+ trie_get_ins(task->ctx->worker->subreq_out, key, klen);
+ if (unlikely(!tvp))
+ return; /*ENOMEM*/
+ if (unlikely(*tvp != NULL)) {
+ assert(false);
+ return;
+ }
+ *tvp = task;
+ task->leading = true;
+}
+
+static bool subreq_enqueue(struct qr_task *task)
+{
+ assert(task);
+ char key[SUBREQ_KEY_LEN];
+ const int klen = subreq_key(key, task->pktbuf);
+ if (klen < 0)
+ return false;
+ struct qr_task **leader = (struct qr_task **)
+ trie_get_try(task->ctx->worker->subreq_out, key, klen);
+ if (!leader /*ENOMEM*/ || !*leader)
+ return false;
+ /* Enqueue itself to leader for this subrequest. */
+ int ret = array_push_mm((*leader)->waiting, task,
+ kr_memreserve, &(*leader)->ctx->req.pool);
+ if (unlikely(ret < 0)) /*ENOMEM*/
+ return false;
+ qr_task_ref(task);
+ return true;
+}
+
+static int qr_task_finalize(struct qr_task *task, int state)
+{
+ assert(task && task->leading == false);
+ if (task->finished) {
+ return 0;
+ }
+ struct request_ctx *ctx = task->ctx;
+ struct session *source_session = ctx->source.session;
+ kr_resolve_finish(&ctx->req, state);
+
+ task->finished = true;
+ if (source_session == NULL) {
+ (void) qr_task_on_send(task, NULL, kr_error(EIO));
+ return state == KR_STATE_DONE ? 0 : kr_error(EIO);
+ }
+
+ /* Reference task as the callback handler can close it */
+ qr_task_ref(task);
+
+ /* Send back answer */
+ assert(!session_flags(source_session)->closing);
+ assert(ctx->source.addr.ip.sa_family != AF_UNSPEC);
+
+ int ret;
+ const uv_handle_t *src_handle = session_get_handle(source_session);
+ if (src_handle->type != UV_UDP && src_handle->type != UV_TCP) {
+ assert(false);
+ ret = kr_error(EINVAL);
+ } else if (src_handle->type == UV_UDP && ENABLE_SENDMMSG) {
+ int fd;
+ ret = uv_fileno(src_handle, &fd);
+ assert(!ret);
+ if (ret == 0) {
+ udp_queue_push(fd, &ctx->req, task);
+ }
+ } else {
+ ret = qr_task_send(task, source_session, &ctx->source.addr.ip, ctx->req.answer);
+ }
+
+ if (ret != kr_ok()) {
+ (void) qr_task_on_send(task, NULL, kr_error(EIO));
+		/* Since the source session is erroneous, detach all tasks. */
+ while (!session_tasklist_is_empty(source_session)) {
+ struct qr_task *t = session_tasklist_del_first(source_session, false);
+ struct request_ctx *c = t->ctx;
+ assert(c->source.session == source_session);
+ c->source.session = NULL;
+ /* Don't finalize them as there can be other tasks
+			 * waiting for an answer to this particular task
+			 * (i.e. task->leading is true). */
+ worker_task_unref(t);
+ }
+ session_close(source_session);
+ }
+
+ qr_task_unref(task);
+
+ return state == KR_STATE_DONE ? 0 : kr_error(EIO);
+}
+
+static int udp_task_step(struct qr_task *task,
+ const struct sockaddr *packet_source, knot_pkt_t *packet)
+{
+ struct request_ctx *ctx = task->ctx;
+ struct kr_request *req = &ctx->req;
+
+	/* If there is already an outgoing query, enqueue to it. */
+ if (subreq_enqueue(task)) {
+ return kr_ok(); /* Will be notified when outgoing query finishes. */
+ }
+ /* Start transmitting */
+ uv_handle_t *handle = retransmit(task);
+ if (handle == NULL) {
+ subreq_finalize(task, packet_source, packet);
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ /* Check current query NSLIST */
+ struct kr_query *qry = array_tail(req->rplan.pending);
+ assert(qry != NULL);
+ /* Retransmit at default interval, or more frequently if the mean
+ * RTT of the server is better. If the server is glued, use default rate. */
+ size_t timeout = qry->ns.score;
+ if (timeout > KR_NS_GLUED) {
+ /* We don't have information about variance in RTT, expect +10ms */
+ timeout = MIN(qry->ns.score + 10, KR_CONN_RETRY);
+ } else {
+ timeout = KR_CONN_RETRY;
+ }
+ /* Announce and start subrequest.
+ * @note Only UDP can lead I/O as it doesn't touch 'task->pktbuf' for reassembly.
+ */
+ subreq_lead(task);
+ struct session *session = handle->data;
+ assert(session_get_handle(session) == handle && (handle->type == UV_UDP));
+ int ret = session_timer_start(session, on_retransmit, timeout, 0);
+ /* Start next step with timeout, fatal if can't start a timer. */
+ if (ret != 0) {
+ subreq_finalize(task, packet_source, packet);
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ return kr_ok();
+}
+
+static int tcp_task_waiting_connection(struct session *session, struct qr_task *task)
+{
+ assert(session_flags(session)->outgoing);
+ if (session_flags(session)->closing) {
+		/* Something went wrong; better to answer with KR_STATE_FAIL.
+		 * TODO: this normally should not happen,
+		 * consider turning this into
+ * assert(!session_flags(session)->closing). */
+ return kr_error(EINVAL);
+ }
+	/* Add the task to the end of the list of waiting tasks.
+ * It will be notified in on_connect() or qr_task_on_send(). */
+ int ret = session_waitinglist_push(session, task);
+ if (ret < 0) {
+ return kr_error(EINVAL);
+ }
+ return kr_ok();
+}
+
+static int tcp_task_existing_connection(struct session *session, struct qr_task *task)
+{
+ assert(session_flags(session)->outgoing);
+ struct request_ctx *ctx = task->ctx;
+ struct worker_ctx *worker = ctx->worker;
+
+ if (session_flags(session)->closing) {
+		/* Something went wrong; better to answer with KR_STATE_FAIL.
+		 * TODO: this normally should not happen,
+		 * consider turning this into
+ * assert(!session_flags(session)->closing). */
+ return kr_error(EINVAL);
+ }
+
+	/* If there are any unsent queries, send them first. */
+ int ret = send_waiting(session);
+ if (ret != 0) {
+ return kr_error(EINVAL);
+ }
+
+ /* No unsent queries at that point. */
+ if (session_tasklist_get_len(session) >= worker->tcp_pipeline_max) {
+		/* Too many outstanding queries; answer with SERVFAIL. */
+ return kr_error(EINVAL);
+ }
+
+ /* Send query to upstream. */
+ ret = qr_task_send(task, session, NULL, NULL);
+ if (ret != 0) {
+ /* Error, finalize task with SERVFAIL and
+ * close connection to upstream. */
+ session_tasklist_finalize(session, KR_STATE_FAIL);
+ worker_del_tcp_connected(worker, session_get_peer(session));
+ session_close(session);
+ return kr_error(EINVAL);
+ }
+
+ return kr_ok();
+}
+
+static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr *addr)
+{
+ struct request_ctx *ctx = task->ctx;
+ struct worker_ctx *worker = ctx->worker;
+
+ /* Check if there must be TLS */
+ struct tls_client_ctx_t *tls_ctx = NULL;
+ struct network *net = &worker->engine->net;
+ tls_client_param_t *entry = tls_client_param_get(net->tls_client_params, addr);
+ if (entry) {
+ /* Address is configured to be used with TLS.
+ * We need to allocate auxiliary data structure. */
+ tls_ctx = tls_client_ctx_new(entry, worker);
+ if (!tls_ctx) {
+ return kr_error(EINVAL);
+ }
+ }
+
+ uv_connect_t *conn = malloc(sizeof(uv_connect_t));
+ if (!conn) {
+ tls_client_ctx_free(tls_ctx);
+ return kr_error(EINVAL);
+ }
+ bool has_tls = (tls_ctx != NULL);
+ uv_handle_t *client = ioreq_spawn(worker, SOCK_STREAM, addr->sa_family, has_tls);
+ if (!client) {
+ tls_client_ctx_free(tls_ctx);
+ free(conn);
+ return kr_error(EINVAL);
+ }
+ struct session *session = client->data;
+ assert(session_flags(session)->has_tls == has_tls);
+ if (has_tls) {
+ tls_client_ctx_set_session(tls_ctx, session);
+ session_tls_set_client_ctx(session, tls_ctx);
+ }
+
+ /* Add address to the waiting list.
+ * Now it "is waiting to be connected to." */
+ int ret = worker_add_tcp_waiting(worker, addr, session);
+ if (ret < 0) {
+ free(conn);
+ session_close(session);
+ return kr_error(EINVAL);
+ }
+
+ conn->data = session;
+ /* Store peer address for the session. */
+ struct sockaddr *peer = session_get_peer(session);
+ memcpy(peer, addr, kr_sockaddr_len(addr));
+
+ /* Start watchdog to catch eventual connection timeout. */
+ ret = session_timer_start(session, on_tcp_connect_timeout,
+ KR_CONN_RTT_MAX, 0);
+ if (ret != 0) {
+ worker_del_tcp_waiting(worker, addr);
+ free(conn);
+ session_close(session);
+ return kr_error(EINVAL);
+ }
+
+ struct kr_query *qry = task_get_last_pending_query(task);
+ WITH_VERBOSE (qry) {
+ const char *peer_str = kr_straddr(peer);
+ VERBOSE_MSG(qry, "=> connecting to: '%s'\n", peer_str ? peer_str : "");
+ }
+
+ /* Start connection process to upstream. */
+	ret = uv_tcp_connect(conn, (uv_tcp_t *)client, addr, on_connect);
+ if (ret != 0) {
+ session_timer_stop(session);
+ worker_del_tcp_waiting(worker, addr);
+ free(conn);
+ session_close(session);
+ unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD;
+ kr_nsrep_update_rtt(NULL, peer, score,
+ worker->engine->resolver.cache_rtt,
+ KR_NS_UPDATE_NORESET);
+ WITH_VERBOSE (qry) {
+ const char *peer_str = kr_straddr(peer);
+ kr_log_verbose( "[wrkr]=> connect to '%s' failed (%s), flagged as 'bad'\n",
+ peer_str ? peer_str : "", uv_strerror(ret));
+ }
+ return kr_error(EAGAIN);
+ }
+
+	/* Add the task to the end of the list of waiting tasks.
+	 * It will be notified either in on_connect() or in qr_task_on_send(). */
+ ret = session_waitinglist_push(session, task);
+ if (ret < 0) {
+ session_timer_stop(session);
+ worker_del_tcp_waiting(worker, addr);
+ free(conn);
+ session_close(session);
+ return kr_error(EINVAL);
+ }
+
+ return kr_ok();
+}
+
+static int tcp_task_step(struct qr_task *task,
+ const struct sockaddr *packet_source, knot_pkt_t *packet)
+{
+ assert(task->pending_count == 0);
+
+ /* target */
+ const struct sockaddr *addr = task->addrlist;
+ if (addr->sa_family == AF_UNSPEC) {
+ /* Target isn't defined. Finalize task with SERVFAIL.
+		 * Although task->pending_count is zero, there can still be followers,
+ * so we need to call subreq_finalize() to handle them properly. */
+ subreq_finalize(task, packet_source, packet);
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ /* Checkout task before connecting */
+ struct request_ctx *ctx = task->ctx;
+ if (kr_resolve_checkout(&ctx->req, NULL, (struct sockaddr *)addr,
+ SOCK_STREAM, task->pktbuf) != 0) {
+ subreq_finalize(task, packet_source, packet);
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ int ret;
+ struct session* session = NULL;
+ if ((session = worker_find_tcp_waiting(ctx->worker, addr)) != NULL) {
+ /* Connection is in the list of waiting connections.
+		 * It means the connection is being established right now. */
+ ret = tcp_task_waiting_connection(session, task);
+ } else if ((session = worker_find_tcp_connected(ctx->worker, addr)) != NULL) {
+ /* Connection has been already established. */
+ ret = tcp_task_existing_connection(session, task);
+ } else {
+ /* Make connection. */
+ ret = tcp_task_make_connection(task, addr);
+ }
+
+ if (ret != kr_ok()) {
+ subreq_finalize(task, addr, packet);
+ if (ret == kr_error(EAGAIN)) {
+ ret = qr_task_step(task, addr, NULL);
+ } else {
+ ret = qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ }
+
+ return ret;
+}
+
+static int qr_task_step(struct qr_task *task,
+ const struct sockaddr *packet_source, knot_pkt_t *packet)
+{
+ /* No more steps after we're finished. */
+ if (!task || task->finished) {
+ return kr_error(ESTALE);
+ }
+
+ /* Close pending I/O requests */
+ subreq_finalize(task, packet_source, packet);
+ if ((kr_now() - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) {
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+
+ /* Consume input and produce next query */
+ struct request_ctx *ctx = task->ctx;
+ assert(ctx);
+ struct kr_request *req = &ctx->req;
+ struct worker_ctx *worker = ctx->worker;
+ int sock_type = -1;
+ task->addrlist = NULL;
+ task->addrlist_count = 0;
+ task->addrlist_turn = 0;
+
+ if (worker->too_many_open) {
+ /* */
+ struct kr_rplan *rplan = &req->rplan;
+ if (worker->stats.rconcurrent <
+ worker->rconcurrent_highwatermark - 10) {
+ worker->too_many_open = false;
+ } else {
+ if (packet && kr_rplan_empty(rplan)) {
+ /* new query; TODO - make this detection more obvious */
+ kr_resolve_consume(req, packet_source, packet);
+ }
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ }
+
+ int state = kr_resolve_consume(req, packet_source, packet);
+ while (state == KR_STATE_PRODUCE) {
+ state = kr_resolve_produce(req, &task->addrlist,
+ &sock_type, task->pktbuf);
+ if (unlikely(++task->iter_count > KR_ITER_LIMIT ||
+ task->timeouts >= KR_TIMEOUT_LIMIT)) {
+ return qr_task_finalize(task, KR_STATE_FAIL);
+ }
+ }
+
+ /* We're done, no more iterations needed */
+ if (state & (KR_STATE_DONE|KR_STATE_FAIL)) {
+ return qr_task_finalize(task, state);
+ } else if (!task->addrlist || sock_type < 0) {
+ return qr_task_step(task, NULL, NULL);
+ }
+
+ /* Count available address choices */
+ struct sockaddr_in6 *choice = (struct sockaddr_in6 *)task->addrlist;
+ for (size_t i = 0; i < KR_NSREP_MAXADDR && choice->sin6_family != AF_UNSPEC; ++i) {
+ task->addrlist_count += 1;
+ choice += 1;
+ }
+
+ /* Upgrade to TLS if the upstream address is configured as DoT capable. */
+ if (task->addrlist_count > 0 && kr_inaddr_port(task->addrlist) == KR_DNS_PORT) {
+ /* TODO if there are multiple addresses (task->addrlist_count > 1)
+ * check all of them. */
+ struct network *net = &worker->engine->net;
+ /* task->addrlist has to contain TLS port before tls_client_param_get() call */
+ kr_inaddr_set_port(task->addrlist, KR_DNS_TLS_PORT);
+ tls_client_param_t *tls_entry =
+ tls_client_param_get(net->tls_client_params, task->addrlist);
+ if (tls_entry) {
+ packet_source = NULL;
+ sock_type = SOCK_STREAM;
+			/* TODO: in this case tcp_task_make_connection() will perform
+			 * a redundant map_get() call. */
+ } else {
+ /* The function is fairly cheap, so we just change there and back. */
+ kr_inaddr_set_port(task->addrlist, KR_DNS_PORT);
+ }
+ }
+
+ int ret = 0;
+ if (sock_type == SOCK_DGRAM) {
+ /* Start fast retransmit with UDP. */
+ ret = udp_task_step(task, packet_source, packet);
+ } else {
+ /* TCP. Connect to upstream or send the query if connection already exists. */
+ assert (sock_type == SOCK_STREAM);
+ ret = tcp_task_step(task, packet_source, packet);
+ }
+ return ret;
+}
+
+static int parse_packet(knot_pkt_t *query)
+{
+ if (!query){
+ return kr_error(EINVAL);
+ }
+
+ /* Parse query packet. */
+ int ret = knot_pkt_parse(query, 0);
+ if (ret == KNOT_ETRAIL) {
+ /* Extra data after message end. */
+ ret = kr_error(EMSGSIZE);
+ } else if (ret != KNOT_EOK) {
+ /* Malformed query. */
+ ret = kr_error(EPROTO);
+ } else {
+ ret = kr_ok();
+ }
+
+ return ret;
+}
+
+int worker_submit(struct session *session, const struct sockaddr *peer, knot_pkt_t *query)
+{
+ if (!session) {
+ assert(false);
+ return kr_error(EINVAL);
+ }
+
+ uv_handle_t *handle = session_get_handle(session);
+ bool OK = handle && handle->loop->data;
+ if (!OK) {
+ assert(false);
+ return kr_error(EINVAL);
+ }
+
+ struct worker_ctx *worker = handle->loop->data;
+
+ /* Parse packet */
+ int ret = parse_packet(query);
+
+ const bool is_query = (knot_wire_get_qr(query->wire) == 0);
+ const bool is_outgoing = session_flags(session)->outgoing;
+ /* Ignore badly formed queries. */
+ if (!query ||
+ (ret != kr_ok() && ret != kr_error(EMSGSIZE)) ||
+ (is_query == is_outgoing)) {
+ if (query && !is_outgoing) worker->stats.dropped += 1;
+ return kr_error(EILSEQ);
+ }
+
+ /* Start new task on listening sockets,
+	 * or resume if this is a subrequest */
+ struct qr_task *task = NULL;
+ const struct sockaddr *addr = NULL;
+ if (!is_outgoing) { /* request from a client */
+ struct request_ctx *ctx = request_create(worker, session, peer,
+ knot_wire_get_id(query->wire));
+ if (!ctx) {
+ return kr_error(ENOMEM);
+ }
+
+ ret = request_start(ctx, query);
+ if (ret != 0) {
+ request_free(ctx);
+ return kr_error(ENOMEM);
+ }
+
+ task = qr_task_create(ctx);
+ if (!task) {
+ request_free(ctx);
+ return kr_error(ENOMEM);
+ }
+
+ if (handle->type == UV_TCP && qr_task_register(task, session)) {
+ return kr_error(ENOMEM);
+ }
+ } else if (query) { /* response from upstream */
+ const uint16_t id = knot_wire_get_id(query->wire);
+ task = session_tasklist_del_msgid(session, id);
+ if (task == NULL) {
+ VERBOSE_MSG(NULL, "=> ignoring packet with mismatching ID %d\n",
+ (int)id);
+ return kr_error(ENOENT);
+ }
+ assert(!session_flags(session)->closing);
+ addr = peer;
+ }
+ assert(uv_is_closing(session_get_handle(session)) == false);
+
+ /* Packet was successfully parsed.
+ * Task was created (found). */
+ session_touch(session);
+ /* Consume input and produce next message */
+ return qr_task_step(task, addr, query);
+}
+
+static int map_add_tcp_session(map_t *map, const struct sockaddr* addr,
+ struct session *session)
+{
+ assert(map && addr);
+ const char *key = tcpsess_key(addr);
+ assert(key);
+ assert(map_contains(map, key) == 0);
+ int ret = map_set(map, key, session);
+ return ret ? kr_error(EINVAL) : kr_ok();
+}
+
+static int map_del_tcp_session(map_t *map, const struct sockaddr* addr)
+{
+ assert(map && addr);
+ const char *key = tcpsess_key(addr);
+ assert(key);
+ int ret = map_del(map, key);
+ return ret ? kr_error(ENOENT) : kr_ok();
+}
+
+static struct session* map_find_tcp_session(map_t *map,
+ const struct sockaddr *addr)
+{
+ assert(map && addr);
+ const char *key = tcpsess_key(addr);
+ assert(key);
+ struct session* ret = map_get(map, key);
+ return ret;
+}
+
+int worker_add_tcp_connected(struct worker_ctx *worker,
+ const struct sockaddr* addr,
+ struct session *session)
+{
+#ifndef NDEBUG
+ assert(addr);
+ const char *key = tcpsess_key(addr);
+ assert(key);
+ assert(map_contains(&worker->tcp_connected, key) == 0);
+#endif
+ return map_add_tcp_session(&worker->tcp_connected, addr, session);
+}
+
+int worker_del_tcp_connected(struct worker_ctx *worker,
+ const struct sockaddr* addr)
+{
+ assert(addr && tcpsess_key(addr));
+ return map_del_tcp_session(&worker->tcp_connected, addr);
+}
+
+static struct session* worker_find_tcp_connected(struct worker_ctx *worker,
+ const struct sockaddr* addr)
+{
+ return map_find_tcp_session(&worker->tcp_connected, addr);
+}
+
+static int worker_add_tcp_waiting(struct worker_ctx *worker,
+ const struct sockaddr* addr,
+ struct session *session)
+{
+#ifndef NDEBUG
+ assert(addr);
+ const char *key = tcpsess_key(addr);
+ assert(key);
+ assert(map_contains(&worker->tcp_waiting, key) == 0);
+#endif
+ return map_add_tcp_session(&worker->tcp_waiting, addr, session);
+}
+
+int worker_del_tcp_waiting(struct worker_ctx *worker,
+ const struct sockaddr* addr)
+{
+ assert(addr && tcpsess_key(addr));
+ return map_del_tcp_session(&worker->tcp_waiting, addr);
+}
+
+static struct session* worker_find_tcp_waiting(struct worker_ctx *worker,
+ const struct sockaddr* addr)
+{
+ return map_find_tcp_session(&worker->tcp_waiting, addr);
+}
+
+int worker_end_tcp(struct session *session)
+{
+ if (!session) {
+ return kr_error(EINVAL);
+ }
+
+ session_timer_stop(session);
+
+ uv_handle_t *handle = session_get_handle(session);
+ struct worker_ctx *worker = handle->loop->data;
+ struct sockaddr *peer = session_get_peer(session);
+
+ worker_del_tcp_waiting(worker, peer);
+ worker_del_tcp_connected(worker, peer);
+ session_flags(session)->connected = false;
+
+ struct tls_client_ctx_t *tls_client_ctx = session_tls_get_client_ctx(session);
+ if (tls_client_ctx) {
+ /* Avoid gnutls_bye() call */
+ tls_set_hs_state(&tls_client_ctx->c, TLS_HS_NOT_STARTED);
+ }
+
+ struct tls_ctx_t *tls_ctx = session_tls_get_server_ctx(session);
+ if (tls_ctx) {
+ /* Avoid gnutls_bye() call */
+ tls_set_hs_state(&tls_ctx->c, TLS_HS_NOT_STARTED);
+ }
+
+ while (!session_waitinglist_is_empty(session)) {
+ struct qr_task *task = session_waitinglist_pop(session, false);
+ assert(task->refs > 1);
+ session_tasklist_del(session, task);
+ if (session_flags(session)->outgoing) {
+ if (task->ctx->req.options.FORWARD) {
+ /* We are in TCP_FORWARD mode.
+ * To prevent failing at kr_resolve_consume()
+ * qry.flags.TCP must be cleared.
+ * TODO - refactoring is needed. */
+ struct kr_request *req = &task->ctx->req;
+ struct kr_rplan *rplan = &req->rplan;
+ struct kr_query *qry = array_tail(rplan->pending);
+ qry->flags.TCP = false;
+ }
+ qr_task_step(task, NULL, NULL);
+ } else {
+ assert(task->ctx->source.session == session);
+ task->ctx->source.session = NULL;
+ }
+ worker_task_unref(task);
+ }
+ while (!session_tasklist_is_empty(session)) {
+ struct qr_task *task = session_tasklist_del_first(session, false);
+ if (session_flags(session)->outgoing) {
+ if (task->ctx->req.options.FORWARD) {
+ struct kr_request *req = &task->ctx->req;
+ struct kr_rplan *rplan = &req->rplan;
+ struct kr_query *qry = array_tail(rplan->pending);
+ qry->flags.TCP = false;
+ }
+ qr_task_step(task, NULL, NULL);
+ } else {
+ assert(task->ctx->source.session == session);
+ task->ctx->source.session = NULL;
+ }
+ worker_task_unref(task);
+ }
+ session_close(session);
+ return kr_ok();
+}
+
+knot_pkt_t * worker_resolve_mk_pkt(const char *qname_str, uint16_t qtype, uint16_t qclass,
+ const struct kr_qflags *options)
+{
+ uint8_t qname[KNOT_DNAME_MAXLEN];
+ if (!knot_dname_from_str(qname, qname_str, sizeof(qname)))
+ return NULL;
+ knot_pkt_t *pkt = knot_pkt_new(NULL, KNOT_EDNS_MAX_UDP_PAYLOAD, NULL);
+ if (!pkt)
+ return NULL;
+ knot_pkt_put_question(pkt, qname, qclass, qtype);
+ knot_wire_set_rd(pkt->wire);
+ knot_wire_set_ad(pkt->wire);
+
+ /* Add OPT RR */
+ pkt->opt_rr = knot_rrset_copy(the_worker->engine->resolver.opt_rr, NULL);
+ if (!pkt->opt_rr) {
+ knot_pkt_free(pkt);
+ return NULL;
+ }
+ if (options->DNSSEC_WANT) {
+ knot_edns_set_do(pkt->opt_rr);
+ }
+ if (options->DNSSEC_CD) {
+ knot_wire_set_cd(pkt->wire);
+ }
+
+ return pkt;
+}
+
+struct qr_task *worker_resolve_start(knot_pkt_t *query, struct kr_qflags options)
+{
+ struct worker_ctx *worker = the_worker;
+ if (!worker || !query) {
+ assert(!EINVAL);
+ return NULL;
+ }
+
+
+ struct request_ctx *ctx = request_create(worker, NULL, NULL, worker->next_request_uid);
+ if (!ctx) {
+ return NULL;
+ }
+
+ /* Create task */
+ struct qr_task *task = qr_task_create(ctx);
+ if (!task) {
+ request_free(ctx);
+ return NULL;
+ }
+
+ /* Start task */
+ int ret = request_start(ctx, query);
+ if (ret != 0) {
+ /* task is attached to request context,
+ * so dereference (and deallocate) it first */
+ ctx->task = NULL;
+ qr_task_unref(task);
+ request_free(ctx);
+ return NULL;
+ }
+
+ worker->next_request_uid += 1;
+ if (worker->next_request_uid == 0) {
+ worker->next_request_uid = UINT16_MAX + 1;
+ }
+
+	/* Set options late, as request_start() -> kr_resolve_begin() rewrites them. */
+ kr_qflags_set(&task->ctx->req.options, options);
+ return task;
+}
+
+int worker_resolve_exec(struct qr_task *task, knot_pkt_t *query)
+{
+ if (!task) {
+ return kr_error(EINVAL);
+ }
+ return qr_task_step(task, NULL, query);
+}
+
+int worker_task_numrefs(const struct qr_task *task)
+{
+ return task->refs;
+}
+
+struct kr_request *worker_task_request(struct qr_task *task)
+{
+ if (!task || !task->ctx) {
+ return NULL;
+ }
+
+ return &task->ctx->req;
+}
+
+int worker_task_finalize(struct qr_task *task, int state)
+{
+ return qr_task_finalize(task, state);
+}
+
+int worker_task_step(struct qr_task *task, const struct sockaddr *packet_source,
+		     knot_pkt_t *packet)
+{
+	return qr_task_step(task, packet_source, packet);
+}
+
+void worker_task_complete(struct qr_task *task)
+{
+ qr_task_complete(task);
+}
+
+void worker_task_ref(struct qr_task *task)
+{
+ qr_task_ref(task);
+}
+
+void worker_task_unref(struct qr_task *task)
+{
+ qr_task_unref(task);
+}
+
+void worker_task_timeout_inc(struct qr_task *task)
+{
+ task->timeouts += 1;
+}
+
+knot_pkt_t *worker_task_get_pktbuf(const struct qr_task *task)
+{
+ return task->pktbuf;
+}
+
+struct request_ctx *worker_task_get_request(struct qr_task *task)
+{
+ return task->ctx;
+}
+
+struct session *worker_request_get_source_session(struct request_ctx *ctx)
+{
+ return ctx->source.session;
+}
+
+void worker_request_set_source_session(struct request_ctx *ctx, struct session *session)
+{
+ ctx->source.session = session;
+}
+
+uint16_t worker_task_pkt_get_msgid(struct qr_task *task)
+{
+ knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
+ uint16_t msg_id = knot_wire_get_id(pktbuf->wire);
+ return msg_id;
+}
+
+void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid)
+{
+ knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
+ knot_wire_set_id(pktbuf->wire, msgid);
+ struct kr_query *q = task_get_last_pending_query(task);
+ q->id = msgid;
+}
+
+uint64_t worker_task_creation_time(struct qr_task *task)
+{
+ return task->creation_time;
+}
+
+void worker_task_subreq_finalize(struct qr_task *task)
+{
+ subreq_finalize(task, NULL, NULL);
+}
+
+bool worker_task_finished(struct qr_task *task)
+{
+ return task->finished;
+}
+
+/** Reserve worker buffers. We assume worker's been zeroed. */
+static int worker_reserve(struct worker_ctx *worker, size_t ring_maxlen)
+{
+ worker->tcp_connected = map_make(NULL);
+ worker->tcp_waiting = map_make(NULL);
+ worker->subreq_out = trie_create(NULL);
+
+ array_init(worker->pool_mp);
+ if (array_reserve(worker->pool_mp, ring_maxlen)) {
+ return kr_error(ENOMEM);
+ }
+
+ worker->pkt_pool.ctx = mp_new (4 * sizeof(knot_pkt_t));
+ worker->pkt_pool.alloc = (knot_mm_alloc_t) mp_alloc;
+
+ return kr_ok();
+}
+
+static inline void reclaim_mp_freelist(mp_freelist_t *list)
+{
+ for (unsigned i = 0; i < list->len; ++i) {
+ struct mempool *e = list->at[i];
+ kr_asan_unpoison(e, sizeof(*e));
+ mp_delete(e);
+ }
+ array_clear(*list);
+}
+
+void worker_deinit(void)
+{
+ struct worker_ctx *worker = the_worker;
+ assert(worker);
+ if (worker->z_import != NULL) {
+ zi_free(worker->z_import);
+ worker->z_import = NULL;
+ }
+ map_clear(&worker->tcp_connected);
+ map_clear(&worker->tcp_waiting);
+ trie_free(worker->subreq_out);
+ worker->subreq_out = NULL;
+
+ reclaim_mp_freelist(&worker->pool_mp);
+ mp_delete(worker->pkt_pool.ctx);
+ worker->pkt_pool.ctx = NULL;
+
+ the_worker = NULL;
+}
+
+int worker_init(struct engine *engine, int worker_id, int worker_count)
+{
+ assert(engine && engine->L);
+ assert(the_worker == NULL);
+ kr_bindings_register(engine->L);
+
+ /* Create main worker. */
+ struct worker_ctx *worker = &the_worker_value;
+ memset(worker, 0, sizeof(*worker));
+ worker->engine = engine;
+
+ uv_loop_t *loop = uv_default_loop();
+ worker->loop = loop;
+
+ worker->id = worker_id;
+ worker->count = worker_count;
+
+ /* Register table for worker per-request variables */
+ lua_newtable(engine->L);
+ lua_setfield(engine->L, -2, "vars");
+ lua_getfield(engine->L, -1, "vars");
+ worker->vars_table_ref = luaL_ref(engine->L, LUA_REGISTRYINDEX);
+ lua_pop(engine->L, 1);
+
+ worker->tcp_pipeline_max = MAX_PIPELINED;
+ worker->out_addr4.sin_family = AF_UNSPEC;
+ worker->out_addr6.sin6_family = AF_UNSPEC;
+
+ int ret = worker_reserve(worker, MP_FREELIST_SIZE);
+ if (ret) return ret;
+ worker->next_request_uid = UINT16_MAX + 1;
+
+ /* Set some worker.* fields in Lua */
+ lua_getglobal(engine->L, "worker");
+ lua_pushnumber(engine->L, worker_id);
+ lua_setfield(engine->L, -2, "id");
+ lua_pushnumber(engine->L, getpid());
+ lua_setfield(engine->L, -2, "pid");
+ lua_pushnumber(engine->L, worker_count);
+ lua_setfield(engine->L, -2, "count");
+
+ the_worker = worker;
+ loop->data = the_worker;
+ /* ^^^^ This shouldn't be used anymore, but it's hard to be 100% sure. */
+ return kr_ok();
+}
+
+#undef VERBOSE_MSG
--- /dev/null
+/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "engine.h"
+#include "lib/generic/array.h"
+#include "lib/generic/map.h"
+
+
+/** Query resolution task (opaque). */
+struct qr_task;
+/** Worker state (opaque). */
+struct worker_ctx;
+/** Transport session (opaque). */
+struct session;
+/** Zone import context (opaque). */
+struct zone_import_ctx;
+
+/** Pointer to the singleton worker. NULL if not initialized. */
+KR_EXPORT extern struct worker_ctx *the_worker;
+
+/** Create and initialize the worker.
+ * \return error code (ENOMEM) */
+int worker_init(struct engine *engine, int worker_id, int worker_count);
+
+/** Destroy the worker (free memory). */
+void worker_deinit(void);
+
+/**
+ * Process an incoming packet (query from a client or answer from upstream).
+ *
+ * @param session session the packet came from
+ * @param peer address the packet came from
+ * @param query the packet, or NULL on an error from the transport layer
+ * @return 0 or an error code
+ */
+int worker_submit(struct session *session, const struct sockaddr *peer, knot_pkt_t *query);
+
+/**
+ * End the current DNS/TCP session; this disassociates pending tasks from the session,
+ * which may then be freely closed.
+ */
+int worker_end_tcp(struct session *session);
+
+/**
+ * Create a packet suitable for worker_resolve_start(). All in malloc() memory.
+ */
+KR_EXPORT knot_pkt_t *
+worker_resolve_mk_pkt(const char *qname_str, uint16_t qtype, uint16_t qclass,
+ const struct kr_qflags *options);
+
+/**
+ * Start query resolution with given query.
+ *
+ * @return task or NULL
+ */
+KR_EXPORT struct qr_task *
+worker_resolve_start(knot_pkt_t *query, struct kr_qflags options);
+
+/**
+ * Execute a request with given query.
+ * It expects the task to have been created with \fn worker_resolve_start.
+ *
+ * @return 0 or an error code
+ */
+KR_EXPORT int worker_resolve_exec(struct qr_task *task, knot_pkt_t *query);
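+/* Illustrative usage sketch (not part of the original header; error handling
+ * is omitted and the query name/type are example values):
+ *
+ *	struct kr_qflags options;
+ *	memset(&options, 0, sizeof(options));
+ *	knot_pkt_t *pkt = worker_resolve_mk_pkt("example.com.", KNOT_RRTYPE_AAAA,
+ *						 KNOT_CLASS_IN, &options);
+ *	struct qr_task *task = worker_resolve_start(pkt, options);
+ *	if (task)
+ *		worker_resolve_exec(task, pkt);  // drives the resolution steps
+ *	knot_pkt_free(pkt);  // request_start() keeps its own copy of the query
+ */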
+
+/** @return struct kr_request associated with opaque task */
+struct kr_request *worker_task_request(struct qr_task *task);
+
+int worker_task_step(struct qr_task *task, const struct sockaddr *packet_source,
+ knot_pkt_t *packet);
+
+int worker_task_numrefs(const struct qr_task *task);
+
+/** Finalize given task */
+int worker_task_finalize(struct qr_task *task, int state);
+
+void worker_task_complete(struct qr_task *task);
+
+void worker_task_ref(struct qr_task *task);
+
+void worker_task_unref(struct qr_task *task);
+
+void worker_task_timeout_inc(struct qr_task *task);
+
+int worker_add_tcp_connected(struct worker_ctx *worker,
+ const struct sockaddr *addr,
+ struct session *session);
+int worker_del_tcp_connected(struct worker_ctx *worker,
+ const struct sockaddr *addr);
+int worker_del_tcp_waiting(struct worker_ctx *worker,
+ const struct sockaddr* addr);
+knot_pkt_t *worker_task_get_pktbuf(const struct qr_task *task);
+
+struct request_ctx *worker_task_get_request(struct qr_task *task);
+
+struct session *worker_request_get_source_session(struct request_ctx *);
+
+void worker_request_set_source_session(struct request_ctx *, struct session *session);
+
+uint16_t worker_task_pkt_get_msgid(struct qr_task *task);
+void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid);
+uint64_t worker_task_creation_time(struct qr_task *task);
+void worker_task_subreq_finalize(struct qr_task *task);
+bool worker_task_finished(struct qr_task *task);
+
+/** To be called after sending a DNS message. It mainly deals with cleanups. */
+int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status);
+
+/** Various worker statistics. Sync with wrk_stats() */
+struct worker_stats {
+ size_t queries; /**< Total number of requests (from clients and internal ones). */
+ size_t concurrent; /**< The number of requests currently in processing. */
+ size_t rconcurrent; /**< TODO: remove? I see no meaningful difference from .concurrent. */
+ size_t dropped; /**< The number of requests dropped due to being badly formed. See #471. */
+
+ size_t timeout; /**< Number of outbound queries that timed out. */
+ size_t udp; /**< Number of outbound queries over UDP. */
+ size_t tcp; /**< Number of outbound queries over TCP (excluding TLS). */
+ size_t tls; /**< Number of outbound queries over TLS. */
+ size_t ipv4; /**< Number of outbound queries over IPv4.*/
+ size_t ipv6; /**< Number of outbound queries over IPv6. */
+};
+
+/** @cond internal */
+
+/** Number of requests within the timeout window. */
+#define MAX_PENDING KR_NSREP_MAXADDR
+
+/** Maximum response time from TCP upstream, milliseconds */
+#define MAX_TCP_INACTIVITY (KR_RESOLVE_TIME_LIMIT + KR_CONN_RTT_MAX)
+
+#ifndef RECVMMSG_BATCH /* see check_bufsize() */
+#define RECVMMSG_BATCH 1
+#endif
+
+/** Freelist of available mempools. */
+typedef array_t(struct mempool *) mp_freelist_t;
+
+/** List of query resolution tasks. */
+typedef array_t(struct qr_task *) qr_tasklist_t;
+
+/** \details Worker state is meant to persist for the whole lifetime of the daemon. */
+struct worker_ctx {
+ struct engine *engine;
+ uv_loop_t *loop;
+ int id;
+ int count;
+ int vars_table_ref;
+ unsigned tcp_pipeline_max;
+
+ /** Addresses to bind for outgoing connections or AF_UNSPEC. */
+ struct sockaddr_in out_addr4;
+ struct sockaddr_in6 out_addr6;
+
+ uint8_t wire_buf[RECVMMSG_BATCH * KNOT_WIRE_MAX_PKTSIZE];
+
+ struct worker_stats stats;
+
+ struct zone_import_ctx* z_import;
+ bool too_many_open;
+ size_t rconcurrent_highwatermark;
+ /** List of active outbound TCP sessions */
+ map_t tcp_connected;
+ /** List of outbound TCP sessions waiting to be accepted */
+ map_t tcp_waiting;
+ /** Subrequest leaders (struct qr_task*), indexed by qname+qtype+qclass. */
+ trie_t *subreq_out;
+ mp_freelist_t pool_mp;
+ knot_mm_t pkt_pool;
+ unsigned int next_request_uid;
+};
+
+/** @endcond */
+
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/* This module imports resource records from a file into the resolver's cache.
+ * The file is expected to be a standard DNS zone file containing
+ * text representations of resource records.
+ * For now only root zone import is supported.
+ *
+ * The import process consists of two stages:
+ * 1) Zone file parsing.
+ * 2) Import of the parsed entries into the cache.
+ *
+ * These stages are implemented as two separate functions
+ * (zi_zone_import and zi_zone_process) which run sequentially with a
+ * pause between them. This is done because the resolver is a single-threaded
+ * application and cannot process user requests during the whole import
+ * process. Separating the work into two stages shortens the continuous
+ * time interval during which the resolver cannot serve user requests.
+ * Since the root zone isn't large, it is imported as a single
+ * chunk. Should it become necessary, the import stage can be
+ * split into shorter stages.
+ *
+ * zi_zone_import() uses libzscanner to parse the zone file.
+ * Parsed records are stored in internal storage, from which they are imported
+ * into the cache during the second stage.
+ *
+ * zi_zone_process() imports the parsed resource records into the cache.
+ * It imports each rrset by creating a request that will never be sent upstream.
+ * After creating the request, the resolver builds a pseudo-answer which must
+ * contain all data necessary for validation. The resolver then processes the
+ * answer as if it had been received from the network.
+ */
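+
+/* A minimal usage sketch of the two-stage flow described above (illustration
+ * only; error handling is omitted and the callback name, zone file name and
+ * TTL are placeholders - see zimport.h for the API contract):
+ *
+ *   static void on_import_done(int state, void *param)
+ *   {
+ *       // state: -1 failure, 0 success, 1 success with non-critical errors
+ *   }
+ *
+ *   struct zone_import_ctx *zi = zi_allocate(worker, on_import_done, NULL);
+ *   if (zi != NULL) {
+ *       // Stage 1 (parsing) runs synchronously inside zi_zone_import();
+ *       // stage 2 (cache import) is scheduled on the worker's event loop
+ *       // and runs after ZONE_IMPORT_PAUSE milliseconds.
+ *       zi_zone_import(zi, "root.zone", ".", KNOT_CLASS_IN, 3600);
+ *   }
+ */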
+
+#include <inttypes.h> /* PRIu64 */
+#include <stdlib.h>
+#include <uv.h>
+#include <ucw/mempool.h>
+#include <libknot/rrset.h>
+#include <libzscanner/scanner.h>
+
+#include "lib/utils.h"
+#include "lib/dnssec/ta.h"
+#include "worker.h"
+#include "zimport.h"
+#include "lib/generic/map.h"
+#include "lib/generic/array.h"
+
+#define VERBOSE_MSG(qry, ...) QRVERBOSE(qry, "zimport", __VA_ARGS__)
+
+/* Pause between the parse and import stages, in milliseconds.
+ * See the explanation at the top of this file. */
+#define ZONE_IMPORT_PAUSE 100
+
+typedef array_t(knot_rrset_t *) qr_rrsetlist_t;
+
+struct zone_import_ctx {
+ struct worker_ctx *worker;
+ bool started;
+ knot_dname_t *origin;
+ knot_rrset_t *ta;
+ knot_rrset_t *key;
+ uint64_t start_timestamp;
+ size_t rrset_idx;
+ uv_timer_t timer;
+ map_t rrset_indexed;
+ qr_rrsetlist_t rrset_sorted;
+ knot_mm_t pool;
+ zi_callback cb;
+ void *cb_param;
+};
+
+typedef struct zone_import_ctx zone_import_ctx_t;
+
+static int RRSET_IS_ALREADY_IMPORTED = 1;
+
+/** @internal Allocate zone import context.
+ * @return pointer to zone import context or NULL. */
+static zone_import_ctx_t *zi_ctx_alloc()
+{
+ return (zone_import_ctx_t *)malloc(sizeof(zone_import_ctx_t));
+}
+
+/** @internal Free zone import context. */
+static void zi_ctx_free(zone_import_ctx_t *z_import)
+{
+ if (z_import != NULL) {
+ free(z_import);
+ }
+}
+
+/** @internal Reset all fields in the zone import context to their default values.
+ * Flushes the memory pool, but doesn't reallocate the pool buffer.
+ * Doesn't affect the timer handle or the callback and callback-parameter pointers.
+ * @return 0 on success; a negative value on failure. */
+static int zi_reset(struct zone_import_ctx *z_import, size_t rrset_sorted_list_size)
+{
+ mp_flush(z_import->pool.ctx);
+
+ z_import->started = false;
+ z_import->start_timestamp = 0;
+ z_import->rrset_idx = 0;
+ z_import->pool.alloc = (knot_mm_alloc_t) mp_alloc;
+ z_import->rrset_indexed = map_make(&z_import->pool);
+
+ array_init(z_import->rrset_sorted);
+
+ int ret = 0;
+ if (rrset_sorted_list_size) {
+ ret = array_reserve_mm(z_import->rrset_sorted, rrset_sorted_list_size,
+ kr_memreserve, &z_import->pool);
+ }
+
+ return ret;
+}
+
+/** @internal Close callback for timer handle.
+ * @note Actually frees zone import context. */
+static void on_timer_close(uv_handle_t *handle)
+{
+ zone_import_ctx_t *z_import = (zone_import_ctx_t *)handle->data;
+ if (z_import != NULL) {
+ zi_ctx_free(z_import);
+ }
+}
+
+zone_import_ctx_t *zi_allocate(struct worker_ctx *worker,
+ zi_callback cb, void *param)
+{
+ if (worker->loop == NULL) {
+ return NULL;
+ }
+ zone_import_ctx_t *z_import = zi_ctx_alloc();
+ if (!z_import) {
+ return NULL;
+ }
+ void *mp = mp_new (8192);
+ if (!mp) {
+ zi_ctx_free(z_import);
+ return NULL;
+ }
+ memset(z_import, 0, sizeof(*z_import));
+ z_import->pool.ctx = mp;
+ z_import->worker = worker;
+ int ret = zi_reset(z_import, 0);
+ if (ret < 0) {
+ mp_delete(mp);
+ zi_ctx_free(z_import);
+ return NULL;
+ }
+ uv_timer_init(z_import->worker->loop, &z_import->timer);
+ z_import->timer.data = z_import;
+ z_import->cb = cb;
+ z_import->cb_param = param;
+ return z_import;
+}
+
+void zi_free(zone_import_ctx_t *z_import)
+{
+ z_import->started = false;
+ z_import->start_timestamp = 0;
+ z_import->rrset_idx = 0;
+ mp_delete(z_import->pool.ctx);
+ z_import->pool.ctx = NULL;
+ z_import->pool.alloc = NULL;
+ z_import->worker = NULL;
+ z_import->cb = NULL;
+ z_import->cb_param = NULL;
+ uv_close((uv_handle_t *)&z_import->timer, on_timer_close);
+}
+
+/** @internal Mark an rrset as already imported
+ * to avoid importing it again. */
+static inline void zi_rrset_mark_as_imported(knot_rrset_t *rr)
+{
+ rr->additional = (void *)&RRSET_IS_ALREADY_IMPORTED;
+}
+
+/** @internal Check whether an rrset is marked as "already imported".
+ * @return true if marked, false otherwise */
+static inline bool zi_rrset_is_marked_as_imported(knot_rrset_t *rr)
+{
+ return (rr->additional == &RRSET_IS_ALREADY_IMPORTED);
+}
+
+/** @internal Try to find an rrset with the given attributes amongst the parsed
+ * rrsets and put it into the given packet. If there is an RRSIG covering that
+ * rrset, it is added as well. When an rrset is found and successfully put, it
+ * is marked as "already imported" to avoid repeated import; the same holds
+ * for the RRSIG.
+ * @return -1 on failure
+ *          0 if the required record was actually put into the packet
+ *          1 if the required record could not be found */
+static int zi_rrset_find_put(struct zone_import_ctx *z_import,
+ knot_pkt_t *pkt, const knot_dname_t *owner,
+ uint16_t class, uint16_t type, uint16_t additional)
+{
+ if (type != KNOT_RRTYPE_RRSIG) {
+ /* If required rrset isn't rrsig, these must be the same values */
+ additional = type;
+ }
+
+ char key[KR_RRKEY_LEN];
+ int err = kr_rrkey(key, class, owner, type, additional);
+ if (err <= 0) {
+ return -1;
+ }
+ knot_rrset_t *rr = map_get(&z_import->rrset_indexed, key);
+ if (!rr) {
+ return 1;
+ }
+ err = knot_pkt_put(pkt, 0, rr, 0);
+ if (err != KNOT_EOK) {
+ return -1;
+ }
+ zi_rrset_mark_as_imported(rr);
+
+ if (type != KNOT_RRTYPE_RRSIG) {
+ /* Try to find corresponding rrsig */
+ err = zi_rrset_find_put(z_import, pkt, owner,
+ class, KNOT_RRTYPE_RRSIG, type);
+ if (err < 0) {
+ return err;
+ }
+ }
+
+ return 0;
+}
+
+/** @internal Try to put the given rrset into the given packet.
+ * If there is an RRSIG covering that rrset, it is added as well.
+ * When the rrset is successfully put into the packet, it is marked as
+ * "already imported" to avoid repeated import.
+ * The same holds for the RRSIG.
+ * @return -1 on failure
+ *          0 if the required record was actually put into the packet */
+static int zi_rrset_put(struct zone_import_ctx *z_import, knot_pkt_t *pkt,
+ knot_rrset_t *rr)
+{
+ assert(rr);
+ assert(rr->type != KNOT_RRTYPE_RRSIG);
+ int err = knot_pkt_put(pkt, 0, rr, 0);
+ if (err != KNOT_EOK) {
+ return -1;
+ }
+ zi_rrset_mark_as_imported(rr);
+ /* Try to find corresponding RRSIG */
+ err = zi_rrset_find_put(z_import, pkt, rr->owner, rr->rclass,
+ KNOT_RRTYPE_RRSIG, rr->type);
+ return (err < 0) ? err : 0;
+}
+
+/** @internal Try to put DS & NSEC* records for rr->owner into the given packet.
+ * @return -1 on failure;
+ *          0 if no errors occurred (which doesn't necessarily mean
+ *          that records were actually added). */
+static int zi_put_delegation(zone_import_ctx_t *z_import, knot_pkt_t *pkt,
+ knot_rrset_t *rr)
+{
+ int err = zi_rrset_find_put(z_import, pkt, rr->owner,
+ rr->rclass, KNOT_RRTYPE_DS, 0);
+ if (err == 1) {
+ /* DS not found, maybe there are NSEC* */
+ err = zi_rrset_find_put(z_import, pkt, rr->owner,
+ rr->rclass, KNOT_RRTYPE_NSEC, 0);
+ if (err >= 0) {
+ err = zi_rrset_find_put(z_import, pkt, rr->owner,
+ rr->rclass, KNOT_RRTYPE_NSEC3, 0);
+ }
+ }
+ return err < 0 ? err : 0;
+}
+
+/** @internal Try to put A & AAAA records for the NS names in rr into the given packet.
+ * @return -1 on failure;
+ *          0 if no errors occurred (which doesn't necessarily mean
+ *          that records were actually added). */
+static int zi_put_glue(zone_import_ctx_t *z_import, knot_pkt_t *pkt,
+ knot_rrset_t *rr)
+{
+ int err = 0;
+ knot_rdata_t *rdata_i = rr->rrs.rdata;
+ for (uint16_t i = 0; i < rr->rrs.count;
+ ++i, rdata_i = knot_rdataset_next(rdata_i)) {
+ const knot_dname_t *ns_name = knot_ns_name(rdata_i);
+ err = zi_rrset_find_put(z_import, pkt, ns_name,
+ rr->rclass, KNOT_RRTYPE_A, 0);
+ if (err < 0) {
+ break;
+ }
+
+ err = zi_rrset_find_put(z_import, pkt, ns_name,
+ rr->rclass, KNOT_RRTYPE_AAAA, 0);
+ if (err < 0) {
+ break;
+ }
+ }
+ return err < 0 ? err : 0;
+}
+
+/** @internal Create a "pseudo query" packet asking for the given rrset. */
+static knot_pkt_t *zi_query_create(zone_import_ctx_t *z_import, knot_rrset_t *rr)
+{
+ knot_mm_t *pool = &z_import->pool;
+
+ uint32_t msgid = kr_rand_bytes(2);
+
+ knot_pkt_t *query = knot_pkt_new(NULL, KNOT_WIRE_MAX_PKTSIZE, pool);
+ if (!query) {
+ return NULL;
+ }
+
+ knot_pkt_put_question(query, rr->owner, rr->rclass, rr->type);
+ knot_pkt_begin(query, KNOT_ANSWER);
+ knot_wire_set_rd(query->wire);
+ knot_wire_set_id(query->wire, msgid);
+ int err = knot_pkt_parse(query, 0);
+ if (err != KNOT_EOK) {
+ knot_pkt_free(query);
+ return NULL;
+ }
+
+ return query;
+}
+
+/** @internal Import the given rrset into the cache.
+ * @return -1 on failure; 0 on success */
+static int zi_rrset_import(zone_import_ctx_t *z_import, knot_rrset_t *rr)
+{
+ /* Create "pseudo query" which asks for given rrset. */
+ knot_pkt_t *query = zi_query_create(z_import, rr);
+ if (!query) {
+ return -1;
+ }
+
+ knot_mm_t *pool = &z_import->pool;
+ uint8_t *dname = rr->owner;
+ uint16_t rrtype = rr->type;
+ uint16_t rrclass = rr->rclass;
+
+ /* Create "pseudo answer". */
+ knot_pkt_t *answer = knot_pkt_new(NULL, KNOT_WIRE_MAX_PKTSIZE, pool);
+ if (!answer) {
+ knot_pkt_free(query);
+ return -1;
+ }
+ knot_pkt_put_question(answer, dname, rrclass, rrtype);
+ knot_pkt_begin(answer, KNOT_ANSWER);
+
+ struct kr_qflags options;
+ memset(&options, 0, sizeof(options));
+ options.DNSSEC_WANT = true;
+ options.NO_MINIMIZE = true;
+
+ /* This call creates the internal structures necessary for
+ * resolution - qr_task & request_ctx. */
+ struct qr_task *task = worker_resolve_start(query, options);
+ if (!task) {
+ knot_pkt_free(query);
+ knot_pkt_free(answer);
+ return -1;
+ }
+
+ /* Push the query onto the request's resolution plan.
+ * The query will never actually be sent upstream. */
+ struct kr_request *request = worker_task_request(task);
+ struct kr_rplan *rplan = &request->rplan;
+ struct kr_query *qry = kr_rplan_push(rplan, NULL, dname, rrclass, rrtype);
+ int state = KR_STATE_FAIL;
+ bool origin_is_owner = knot_dname_is_equal(rr->owner, z_import->origin);
+ bool is_referral = (rrtype == KNOT_RRTYPE_NS && !origin_is_owner);
+ uint32_t msgid = knot_wire_get_id(query->wire);
+
+ qry->id = msgid;
+
+ /* Prepare the zone cut. It must have everything needed for
+ * successful validation - matching zone name & keys & trust anchors. */
+ kr_zonecut_init(&qry->zone_cut, z_import->origin, pool);
+ qry->zone_cut.key = z_import->key;
+ qry->zone_cut.trust_anchor = z_import->ta;
+
+ if (knot_pkt_init_response(request->answer, query) != 0) {
+ goto cleanup;
+ }
+
+ /* Since "pseudo" query asks for NS for subzone,
+ * "pseudo" answer must simulate referral. */
+ if (is_referral) {
+ knot_pkt_begin(answer, KNOT_AUTHORITY);
+ }
+
+ /* Put the target rrset into ANSWER/AUTHORITY as well as the corresponding RRSIG */
+ int err = zi_rrset_put(z_import, answer, rr);
+ if (err != 0) {
+ goto cleanup;
+ }
+
+ if (!is_referral) {
+ knot_wire_set_aa(answer->wire);
+ } else {
+ /* The type is KNOT_RRTYPE_NS and the owner is not equal to the origin,
+ * so this will be a "referral" answer and must contain a delegation. */
+ err = zi_put_delegation(z_import, answer, rr);
+ if (err < 0) {
+ goto cleanup;
+ }
+ }
+
+ knot_pkt_begin(answer, KNOT_ADDITIONAL);
+
+ if (rrtype == KNOT_RRTYPE_NS) {
+ /* Try to find glue addresses. */
+ err = zi_put_glue(z_import, answer, rr);
+ if (err < 0) {
+ goto cleanup;
+ }
+ }
+
+ knot_wire_set_id(answer->wire, msgid);
+ answer->parsed = answer->size;
+ err = knot_pkt_parse(answer, 0);
+ if (err != KNOT_EOK) {
+ goto cleanup;
+ }
+
+ /* Importing doesn't involve any communication with upstream at all.
+ * "answer" contains the pseudo-answer from upstream and must be successfully
+ * validated in the CONSUME stage. If it isn't, something went wrong. */
+ state = kr_resolve_consume(request, NULL, answer);
+
+cleanup:
+
+ knot_pkt_free(query);
+ knot_pkt_free(answer);
+ worker_task_finalize(task, state);
+ return state == (is_referral ? KR_STATE_PRODUCE : KR_STATE_DONE) ? 0 : -1;
+}
+
+/** @internal Append an element to qr_rrsetlist_t rrset_sorted for
+ * the given node of map_t rrset_indexed. */
+static int zi_mapwalk_preprocess(const char *k, void *v, void *baton)
+{
+ zone_import_ctx_t *z_import = (zone_import_ctx_t *)baton;
+
+ int ret = array_push_mm(z_import->rrset_sorted, v, kr_memreserve, &z_import->pool);
+
+ return (ret < 0);
+}
+
+/** @internal Iterate over parsed rrsets and try to import each of them. */
+static void zi_zone_process(uv_timer_t* handle)
+{
+ zone_import_ctx_t *z_import = (zone_import_ctx_t *)handle->data;
+
+ assert(z_import->worker);
+
+ size_t failed = 0;
+ size_t ns_imported = 0;
+ size_t other_imported = 0;
+
+ /* At the moment only root zone import is supported.
+ * Check the name of the parsed zone.
+ * TODO - implement import of arbitrary zones. */
+ KR_DNAME_GET_STR(zone_name_str, z_import->origin);
+
+ if (strcmp(".", zone_name_str) != 0) {
+ kr_log_error("[zimport] unexpected zone name `%s` (root zone expected), fail\n",
+ zone_name_str);
+ failed = 1;
+ goto finish;
+ }
+
+ if (z_import->rrset_sorted.len <= 0) {
+ VERBOSE_MSG(NULL, "zone is empty\n");
+ goto finish;
+ }
+
+ /* A TA has been found, so the zone is secured.
+ * The DNSKEY must be somewhere amongst the imported records. Find it.
+ * TODO - this step must be skipped for zones that provably have no TA. */
+ char key[KR_RRKEY_LEN];
+ int err = kr_rrkey(key, KNOT_CLASS_IN, z_import->origin,
+ KNOT_RRTYPE_DNSKEY, KNOT_RRTYPE_DNSKEY);
+ if (err <= 0) {
+ failed = 1;
+ goto finish;
+ }
+
+ knot_rrset_t *rr_key = map_get(&z_import->rrset_indexed, key);
+ if (!rr_key) {
+ /* DNSKEY MUST be here. If not found - fail. */
+ kr_log_error("[zimport] DNSKEY not found for `%s`, fail\n", zone_name_str);
+ failed = 1;
+ goto finish;
+ }
+ z_import->key = rr_key;
+
+ VERBOSE_MSG(NULL, "started: zone: '%s'\n", zone_name_str);
+
+ z_import->start_timestamp = kr_now();
+
+ /* Import the DNSKEY first. If any validation problems appear,
+ * cancel the import of the whole zone. */
+ KR_DNAME_GET_STR(kname_str, rr_key->owner);
+ KR_RRTYPE_GET_STR(ktype_str, rr_key->type);
+
+ VERBOSE_MSG(NULL, "importing: name: '%s' type: '%s'\n",
+ kname_str, ktype_str);
+
+ int res = zi_rrset_import(z_import, rr_key);
+ if (res != 0) {
+ VERBOSE_MSG(NULL, "import failed: qname: '%s' type: '%s'\n",
+ kname_str, ktype_str);
+ failed = 1;
+ goto finish;
+ }
+
+ /* Import all NS records */
+ for (size_t i = 0; i < z_import->rrset_sorted.len; ++i) {
+ knot_rrset_t *rr = z_import->rrset_sorted.at[i];
+
+ if (rr->type != KNOT_RRTYPE_NS) {
+ continue;
+ }
+
+ KR_DNAME_GET_STR(name_str, rr->owner);
+ KR_RRTYPE_GET_STR(type_str, rr->type);
+ VERBOSE_MSG(NULL, "importing: name: '%s' type: '%s'\n",
+ name_str, type_str);
+ int ret = zi_rrset_import(z_import, rr);
+ if (ret == 0) {
+ ++ns_imported;
+ } else {
+ VERBOSE_MSG(NULL, "import failed: name: '%s' type: '%s'\n",
+ name_str, type_str);
+ ++failed;
+ }
+ z_import->rrset_sorted.at[i] = NULL;
+ }
+
+ /* The NS records have been imported, along with the related DS, NSEC* and glue.
+ * Now import what's left. */
+ for (size_t i = 0; i < z_import->rrset_sorted.len; ++i) {
+
+ knot_rrset_t *rr = z_import->rrset_sorted.at[i];
+ if (rr == NULL) {
+ continue;
+ }
+
+ if (zi_rrset_is_marked_as_imported(rr)) {
+ continue;
+ }
+
+ if (rr->type == KNOT_RRTYPE_DNSKEY || rr->type == KNOT_RRTYPE_RRSIG) {
+ continue;
+ }
+
+ KR_DNAME_GET_STR(name_str, rr->owner);
+ KR_RRTYPE_GET_STR(type_str, rr->type);
+ VERBOSE_MSG(NULL, "importing: name: '%s' type: '%s'\n",
+ name_str, type_str);
+ res = zi_rrset_import(z_import, rr);
+ if (res == 0) {
+ ++other_imported;
+ } else {
+ VERBOSE_MSG(NULL, "import failed: name: '%s' type: '%s'\n",
+ name_str, type_str);
+ ++failed;
+ }
+ }
+
+ uint64_t elapsed = kr_now() - z_import->start_timestamp;
+ elapsed = elapsed > UINT_MAX ? UINT_MAX : elapsed;
+
+ VERBOSE_MSG(NULL, "finished in %"PRIu64" ms; zone: `%s`; ns: %zd"
+ "; other: %zd; failed: %zd\n",
+ elapsed, zone_name_str, ns_imported, other_imported, failed);
+
+finish:
+
+ uv_timer_stop(&z_import->timer);
+ z_import->started = false;
+
+ int import_state = 0;
+
+ if (failed != 0) {
+ if (ns_imported == 0 && other_imported == 0) {
+ import_state = -1;
+ VERBOSE_MSG(NULL, "import failed; zone `%s` \n", zone_name_str);
+ } else {
+ import_state = 1;
+ }
+ } else {
+ import_state = 0;
+ }
+
+ if (z_import->cb != NULL) {
+ z_import->cb(import_state, z_import->cb_param);
+ }
+}
+
+/** @internal Store a parsed rrset into the zone import context's memory pool.
+ * @return -1 on failure; 0 on success. */
+static int zi_record_store(zs_scanner_t *s)
+{
+ if (s->r_data_length > UINT16_MAX) {
+ /* Due to knot_rrset_add_rdata(..., const uint16_t size, ...); */
+ kr_log_error("[zscanner] line %"PRIu64": rdata is too long\n",
+ s->line_counter);
+ return -1;
+ }
+
+ if (knot_dname_size(s->r_owner) != strlen((const char *)(s->r_owner)) + 1) {
+ kr_log_error("[zscanner] line %"PRIu64
+ ": owner name contains zero byte, skip\n",
+ s->line_counter);
+ return 0;
+ }
+
+ zone_import_ctx_t *z_import = (zone_import_ctx_t *)s->process.data;
+
+ knot_rrset_t *new_rr = knot_rrset_new(s->r_owner, s->r_type, s->r_class,
+ s->r_ttl, &z_import->pool);
+ if (!new_rr) {
+ kr_log_error("[zscanner] line %"PRIu64": error creating rrset\n",
+ s->line_counter);
+ return -1;
+ }
+ int res = knot_rrset_add_rdata(new_rr, s->r_data, s->r_data_length,
+ &z_import->pool);
+ if (res != KNOT_EOK) {
+ kr_log_error("[zscanner] line %"PRIu64": error adding rdata to rrset\n",
+ s->line_counter);
+ return -1;
+ }
+
+ /* Records in a zone file may not be grouped by name and RR type.
+ * Use a map with a constructed search key to
+ * avoid inefficient searches across all the imported records. */
+ char key[KR_RRKEY_LEN];
+ uint16_t additional_key_field = kr_rrset_type_maysig(new_rr);
+
+ res = kr_rrkey(key, new_rr->rclass, new_rr->owner, new_rr->type,
+ additional_key_field);
+ if (res <= 0) {
+ kr_log_error("[zscanner] line %"PRIu64": error constructing rrkey\n",
+ s->line_counter);
+ return -1;
+ }
+
+ knot_rrset_t *saved_rr = map_get(&z_import->rrset_indexed, key);
+ if (saved_rr) {
+ res = knot_rdataset_merge(&saved_rr->rrs, &new_rr->rrs,
+ &z_import->pool);
+ } else {
+ res = map_set(&z_import->rrset_indexed, key, new_rr);
+ }
+ if (res != 0) {
+ kr_log_error("[zscanner] line %"PRIu64": error saving parsed rrset\n",
+ s->line_counter);
+ return -1;
+ }
+
+ return 0;
+}
+
+/** @internal Parse the zone file record by record using zscanner. */
+static int zi_state_parsing(zs_scanner_t *s)
+{
+ bool empty = true;
+ while (zs_parse_record(s) == 0) {
+ switch (s->state) {
+ case ZS_STATE_DATA:
+ if (zi_record_store(s) != 0) {
+ return -1;
+ }
+ zone_import_ctx_t *z_import = (zone_import_ctx_t *) s->process.data;
+ empty = false;
+ /* The SOA owner defines the zone origin. */
+ if (s->r_type == KNOT_RRTYPE_SOA) {
+ z_import->origin = knot_dname_copy(s->r_owner,
+ &z_import->pool);
+ }
+ break;
+ case ZS_STATE_ERROR:
+ kr_log_error("[zscanner] line: %"PRIu64
+ ": parse error; code: %i ('%s')\n",
+ s->line_counter, s->error.code,
+ zs_strerror(s->error.code));
+ return -1;
+ case ZS_STATE_INCLUDE:
+ kr_log_error("[zscanner] line: %"PRIu64
+ ": INCLUDE is not supported\n",
+ s->line_counter);
+ return -1;
+ case ZS_STATE_EOF:
+ case ZS_STATE_STOP:
+ if (empty) {
+ kr_log_error("[zimport] empty zone file\n");
+ return -1;
+ }
+ if (!((zone_import_ctx_t *) s->process.data)->origin) {
+ kr_log_error("[zimport] zone file doesn't contain SOA record\n");
+ return -1;
+ }
+ return (s->error.counter == 0) ? 0 : -1;
+ default:
+ kr_log_error("[zscanner] line: %"PRIu64
+ ": unexpected parse state: %i\n",
+ s->line_counter, s->state);
+ return -1;
+ }
+ }
+
+ return -1;
+}
+
+int zi_zone_import(struct zone_import_ctx *z_import,
+ const char *zone_file, const char *origin,
+ uint16_t rclass, uint32_t ttl)
+{
+ assert (z_import != NULL && "[zimport] empty <z_import> parameter");
+ assert (z_import->worker != NULL && "[zimport] invalid <z_import> parameter\n");
+ assert (zone_file != NULL && "[zimport] empty <zone_file> parameter\n");
+
+ zs_scanner_t *s = malloc(sizeof(zs_scanner_t));
+ if (s == NULL) {
+ kr_log_error("[zscanner] error creating instance of zone scanner (malloc() fails)\n");
+ return -1;
+ }
+
+ /* zs_init(), zs_set_input_file() and zs_set_processing() return -1 on error,
+ * so don't print their return value - it carries no useful information. */
+ int res = zs_init(s, origin, rclass, ttl);
+ if (res != 0) {
+ kr_log_error("[zscanner] error initializing zone scanner instance, error: %i (%s)\n",
+ s->error.code, zs_strerror(s->error.code));
+ free(s);
+ return -1;
+ }
+
+ res = zs_set_input_file(s, zone_file);
+ if (res != 0) {
+ kr_log_error("[zscanner] error opening zone file `%s`, error: %i (%s)\n",
+ zone_file, s->error.code, zs_strerror(s->error.code));
+ zs_deinit(s);
+ free(s);
+ return -1;
+ }
+
+ /* Don't set processing and error callbacks, as we don't use automatic parsing.
+ * Parsing as well as error processing is performed in zi_state_parsing().
+ * Store a pointer to the zone import context for later use. */
+ if (zs_set_processing(s, NULL, NULL, (void *)z_import) != 0) {
+ kr_log_error("[zscanner] zs_set_processing() failed for zone file `%s`, "
+ "error: %i (%s)\n",
+ zone_file, s->error.code, zs_strerror(s->error.code));
+ zs_deinit(s);
+ free(s);
+ return -1;
+ }
+
+ uint64_t elapsed = 0;
+ int ret = zi_reset(z_import, 4096);
+ if (ret == 0) {
+ z_import->started = true;
+ z_import->start_timestamp = kr_now();
+ VERBOSE_MSG(NULL, "[zscanner] started; zone file `%s`\n",
+ zone_file);
+ ret = zi_state_parsing(s);
+ if (ret == 0) {
+ /* Try to find a TA for z_import->origin. */
+ map_t *trust_anchors = &z_import->worker->engine->resolver.trust_anchors;
+ knot_rrset_t *rr = kr_ta_get(trust_anchors, z_import->origin);
+ if (rr) {
+ z_import->ta = rr;
+ } else {
+ /* For now - fail.
+ * TODO - query the DS and continue once the answer has been obtained. */
+ KR_DNAME_GET_STR(zone_name_str, z_import->origin);
+ kr_log_error("[zimport] no TA found for `%s`, fail\n", zone_name_str);
+ ret = 1;
+ }
+ elapsed = kr_now() - z_import->start_timestamp;
+ elapsed = elapsed > UINT_MAX ? UINT_MAX : elapsed;
+ }
+ }
+ zs_deinit(s);
+ free(s);
+
+ if (ret != 0) {
+ kr_log_error("[zscanner] error parsing zone file `%s`\n", zone_file);
+ z_import->started = false;
+ return ret;
+ }
+
+ VERBOSE_MSG(NULL, "[zscanner] finished in %"PRIu64" ms; zone file `%s`\n",
+ elapsed, zone_file);
+ map_walk(&z_import->rrset_indexed, zi_mapwalk_preprocess, z_import);
+
+ /* The zone has been parsed already, so start the import. */
+ uv_timer_start(&z_import->timer, zi_zone_process,
+ ZONE_IMPORT_PAUSE, ZONE_IMPORT_PAUSE);
+
+ return 0;
+}
+
+bool zi_import_started(struct zone_import_ctx *z_import)
+{
+ return z_import ? z_import->started : false;
+}
--- /dev/null
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <stdbool.h>
+
+struct worker_ctx;
+/** Zone import context (opaque). */
+struct zone_import_ctx;
+
+/**
+ * Completion callback
+ *
+ * @param state -1 - failure
+ *               0 - success
+ *               1 - success, but with non-critical errors
+ * @param param pointer to user data
+ */
+typedef void (*zi_callback)(int state, void *param);
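+
+/* A sketch of a conforming callback (illustration only; the function name and
+ * log messages are placeholders):
+ *
+ *   static void my_import_done(int state, void *param)
+ *   {
+ *       if (state < 0)
+ *           kr_log_error("[zimport] import failed\n");
+ *       else if (state == 1)
+ *           kr_log_error("[zimport] import finished with non-critical errors\n");
+ *       // state == 0: full success
+ *   }
+ */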
+
+/**
+ * Allocate and initialize a zone import context.
+ *
+ * @param worker pointer to worker state
+ * @param cb completion callback, called when the import finishes
+ * @param param user data passed to the callback
+ * @return pointer to the zone import context, or NULL on failure.
+ */
+struct zone_import_ctx *zi_allocate(struct worker_ctx *worker,
+ zi_callback cb, void *param);
+
+/** Free zone import context. */
+void zi_free(struct zone_import_ctx *z_import);
+
+/**
+ * Import zone from file.
+ *
+ * @note only root zone import is supported; origin must be NULL or "."
+ * @param z_import pointer to zone import context
+ * @param zone_file zone file name
+ * @param origin default origin
+ * @param rclass default class
+ * @param ttl default ttl
+ * @return 0 or an error code
+ */
+int zi_zone_import(struct zone_import_ctx *z_import,
+ const char *zone_file, const char *origin,
+ uint16_t rclass, uint32_t ttl);
+
+/**
+ * Check whether an import is already in progress.
+ *
+ * @param z_import pointer to zone import context.
+ * @return true if an import is already in progress; false otherwise.
+ */
+bool zi_import_started(struct zone_import_ctx *z_import);