From db4b099a9eb481e646acbebffe569dd35326079a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Ale=C5=A1?= Date: Thu, 11 Jun 2020 14:14:31 +0200 Subject: [PATCH] demo --- makefile | 9 + modules/sysrepo-lua/meson.build | 2 + modules/sysrepo/common/sysrepo.c | 172 ++ modules/sysrepo/common/sysrepo.h | 45 + modules/sysrepo/meson.build | 34 +- utils/client/meson.build | 24 +- utils/kres_watcher/main.c | 4 - utils/kres_watcher/meson.build | 22 - utils/kresctl/commands.c | 572 ++++++ utils/kresctl/commands.h | 60 + utils/kresctl/conf_file.c | 63 + utils/kresctl/conf_file.h | 7 + utils/kresctl/deps/ctype.h | 193 ++ utils/kresctl/deps/lookup.c | 279 +++ utils/kresctl/deps/lookup.h | 112 ++ utils/kresctl/deps/mempattern.c | 122 ++ utils/kresctl/deps/mempattern.h | 47 + utils/kresctl/deps/qp-trie/trie.c | 1451 +++++++++++++++ utils/kresctl/deps/qp-trie/trie.h | 280 +++ utils/kresctl/deps/string.c | 215 +++ utils/kresctl/deps/string.h | 93 + utils/kresctl/deps/ucw/LICENSE | 1 + utils/kresctl/deps/ucw/array-sort.h | 195 ++ utils/kresctl/deps/ucw/binsearch.h | 50 + utils/kresctl/deps/ucw/heap.c | 166 ++ utils/kresctl/deps/ucw/heap.h | 46 + utils/kresctl/deps/ucw/lists.c | 235 +++ utils/kresctl/deps/ucw/lists.h | 84 + utils/kresctl/deps/ucw/mempool.c | 323 ++++ utils/kresctl/deps/ucw/mempool.h | 124 ++ utils/kresctl/interactive.c | 170 ++ utils/kresctl/interactive.h | 7 + utils/kresctl/main.c | 121 +- utils/kresctl/meson.build | 17 +- utils/kresctl/process.c | 66 + utils/kresctl/process.h | 40 + utils/meson.build | 2 +- utils/watcher/bindings/api.h | 24 + utils/watcher/bindings/cache.c | 467 +++++ utils/watcher/bindings/event.c | 225 +++ utils/watcher/bindings/impl.c | 80 + utils/watcher/bindings/impl.h | 95 + utils/watcher/bindings/modules.c | 91 + utils/watcher/bindings/net.c | 1042 +++++++++++ utils/watcher/bindings/watcher.c | 74 + utils/watcher/bindings/worker.c | 86 + utils/watcher/dbus_control.c | 119 ++ utils/watcher/dbus_control.h | 29 + utils/watcher/engine.c | 876 +++++++++ utils/watcher/engine.h | 84 + utils/watcher/ffimodule.c | 307 +++ utils/watcher/ffimodule.h | 48 + utils/watcher/io.c | 515 ++++++ utils/watcher/io.h | 47 + utils/watcher/lua/config-watcher.lua | 5 + utils/watcher/lua/meson.build | 18 + utils/watcher/lua/sandbox-watcher.lua.in | 631 +++++++ utils/watcher/main.c | 938 ++++++++++ utils/watcher/meson.build | 56 + utils/watcher/network.c | 552 ++++++ utils/watcher/network.h | 129 ++ utils/watcher/session.c | 776 ++++++++ utils/watcher/session.h | 151 ++ utils/watcher/sr_subscriptions.c | 435 +++++ utils/watcher/sr_subscriptions.h | 10 + utils/watcher/tls.c | 1197 ++++++++++++ utils/watcher/tls.h | 242 +++ utils/watcher/tls_ephemeral_credentials.c | 249 +++ utils/watcher/tls_session_ticket-srv.c | 262 +++ utils/watcher/udp_queue.c | 170 ++ utils/watcher/udp_queue.h | 28 + utils/watcher/watcher.c | 143 ++ utils/watcher/watcher.h | 37 + utils/watcher/worker.c | 2060 +++++++++++++++++++++ utils/watcher/worker.h | 190 ++ utils/watcher/zimport.c | 821 ++++++++ utils/watcher/zimport.h | 68 + 77 files changed, 18771 insertions(+), 59 deletions(-) create mode 100644 makefile create mode 100644 modules/sysrepo/common/sysrepo.c create mode 100644 modules/sysrepo/common/sysrepo.h delete mode 100644 utils/kres_watcher/main.c delete mode 100644 utils/kres_watcher/meson.build create mode 100644 utils/kresctl/commands.c create mode 100644 utils/kresctl/commands.h create mode 100644 utils/kresctl/conf_file.c create mode 100644 utils/kresctl/conf_file.h create mode 100644 utils/kresctl/deps/ctype.h create mode 100644 utils/kresctl/deps/lookup.c create mode 100644 utils/kresctl/deps/lookup.h create mode 100644 utils/kresctl/deps/mempattern.c create mode 100644 utils/kresctl/deps/mempattern.h create mode 100644 utils/kresctl/deps/qp-trie/trie.c create mode 100644 utils/kresctl/deps/qp-trie/trie.h create mode 100644 utils/kresctl/deps/string.c create mode 100644 utils/kresctl/deps/string.h create mode 100644 utils/kresctl/deps/ucw/LICENSE create mode 100644 utils/kresctl/deps/ucw/array-sort.h create mode 100644 utils/kresctl/deps/ucw/binsearch.h create mode 100644 utils/kresctl/deps/ucw/heap.c create mode 100644 utils/kresctl/deps/ucw/heap.h create mode 100644 utils/kresctl/deps/ucw/lists.c create mode 100644 utils/kresctl/deps/ucw/lists.h create mode 100644 utils/kresctl/deps/ucw/mempool.c create mode 100644 utils/kresctl/deps/ucw/mempool.h create mode 100644 utils/kresctl/interactive.c create mode 100644 utils/kresctl/interactive.h create mode 100644 utils/kresctl/process.c create mode 100644 utils/kresctl/process.h create mode 100644 utils/watcher/bindings/api.h create mode 100644 utils/watcher/bindings/cache.c create mode 100644 utils/watcher/bindings/event.c create mode 100644 utils/watcher/bindings/impl.c create mode 100644 utils/watcher/bindings/impl.h create mode 100644 utils/watcher/bindings/modules.c create mode 100644 utils/watcher/bindings/net.c create mode 100644 utils/watcher/bindings/watcher.c create mode 100644 utils/watcher/bindings/worker.c create mode 100644 utils/watcher/dbus_control.c create mode 100644 utils/watcher/dbus_control.h create mode 100644 utils/watcher/engine.c create mode 100644 utils/watcher/engine.h create mode 100644 utils/watcher/ffimodule.c create mode 100644 utils/watcher/ffimodule.h create mode 100644 utils/watcher/io.c create mode 100644 utils/watcher/io.h create mode 100644 utils/watcher/lua/config-watcher.lua create mode 100644 utils/watcher/lua/meson.build create mode 100644 utils/watcher/lua/sandbox-watcher.lua.in create mode 100644 utils/watcher/main.c create mode 100644 utils/watcher/meson.build create mode 100644 utils/watcher/network.c create mode 100644 utils/watcher/network.h create mode 100644 utils/watcher/session.c create mode 100644 utils/watcher/session.h create mode 100644 utils/watcher/sr_subscriptions.c create mode 100644 utils/watcher/sr_subscriptions.h create mode 100644 utils/watcher/tls.c create mode 100644 utils/watcher/tls.h create mode 100644 utils/watcher/tls_ephemeral_credentials.c create mode 100644 utils/watcher/tls_session_ticket-srv.c create mode 100644 utils/watcher/udp_queue.c create mode 100644 utils/watcher/udp_queue.h create mode 100644 utils/watcher/watcher.c create mode 100644 utils/watcher/watcher.h create mode 100644 utils/watcher/worker.c create mode 100644 utils/watcher/worker.h create mode 100644 utils/watcher/zimport.c create mode 100644 utils/watcher/zimport.h diff --git a/makefile b/makefile new file mode 100644 index 000000000..9141edcc4 --- /dev/null +++ b/makefile @@ -0,0 +1,9 @@ +install: + rm -rf build + rm -rf /tmp/kr + meson build -Db_sanitize=address --prefix=/tmp/kr --default-library=static + ninja -C build + ninja install -C build + +shm_clean: + make -C /home/jetconf/kres-sysrepo/sysrepo/build shm_clean \ No newline at end of file diff --git a/modules/sysrepo-lua/meson.build b/modules/sysrepo-lua/meson.build index 05a145510..e1839fdc8 100644 --- a/modules/sysrepo-lua/meson.build +++ b/modules/sysrepo-lua/meson.build @@ -27,6 +27,8 @@ sysrepo_common_src = files([ 'common/sysrepo_conf.h', 'common/string_helper.h', 'common/string_helper.c', + 'common/sysrepo.h', + 'common/sysrepo.c', ]) c_src_lint += sysrepo_common_src diff --git a/modules/sysrepo/common/sysrepo.c b/modules/sysrepo/common/sysrepo.c new file mode 100644 index 000000000..bfc69edef --- /dev/null +++ b/modules/sysrepo/common/sysrepo.c @@ -0,0 +1,172 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "contrib/ccan/asprintf/asprintf.h" +#include "lib/utils.h" +#include "sysrepo.h" + + +int valtostr(const sr_val_t *value, char **strval) +{ + if (NULL == value) { + return 1; + } + + switch (value->type) { + case SR_CONTAINER_T: + case SR_CONTAINER_PRESENCE_T: + break; + case SR_LIST_T: + break; + case SR_STRING_T: + asprintf(strval, "%s", value->data.string_val); + break; + case SR_BOOL_T: + asprintf(strval, "%s", value->data.bool_val ? "true" : "false"); + break; + case SR_DECIMAL64_T: + asprintf(strval, "%g", value->data.decimal64_val); + break; + case SR_INT8_T: + asprintf(strval, "%"PRId8, value->data.int8_val); + break; + case SR_INT16_T: + asprintf(strval, "%"PRId16, value->data.int16_val); + break; + case SR_INT32_T: + asprintf(strval, "%"PRId32, value->data.int32_val); + break; + case SR_INT64_T: + asprintf(strval, "%"PRId64, value->data.int64_val); + break; + case SR_UINT8_T: + asprintf(strval, "%"PRIu8, value->data.uint8_val); + break; + case SR_UINT16_T: + asprintf(strval, "%"PRIu16, value->data.uint16_val); + break; + case SR_UINT32_T: + asprintf(strval, "%"PRIu32, value->data.uint32_val); + break; + case SR_UINT64_T: + asprintf(strval, "%"PRIu64, value->data.uint64_val); + break; + case SR_IDENTITYREF_T: + asprintf(strval, "%s", value->data.identityref_val); + break; + case SR_INSTANCEID_T: + asprintf(strval, "%s", value->data.instanceid_val); + break; + case SR_BITS_T: + asprintf(strval, "%s", value->data.bits_val); + break; + case SR_BINARY_T: + asprintf(strval, "%s", value->data.binary_val); + break; + case SR_ENUM_T: + asprintf(strval, "%s", value->data.enum_val); + break; + case SR_LEAF_EMPTY_T: + break; + default: + break; + } + + return 0; +} + +static void sysrepo_subscr_finish_closing(uv_handle_t *handle) +{ + sysrepo_uv_ctx_t *sysrepo = handle->data; + assert(sysrepo); + free(sysrepo); +} + +/** Free a event loop subscription. */ +static void sysrepo_subscription_free(sysrepo_uv_ctx_t *sysrepo) +{ + sr_disconnect(sysrepo->connection); + uv_close((uv_handle_t *)&sysrepo->uv_handle, sysrepo_subscr_finish_closing); +} + +static void sysrepo_subscr_cb_tramp(uv_poll_t *handle, int status, int events) +{ + sysrepo_uv_ctx_t *sysrepo = handle->data; + sysrepo->callback(sysrepo, status); +} + +static void sysrepo_subscr_cb(sysrepo_uv_ctx_t *sysrepo, int status) +{ + if (status) { + /* some error */ + return; + } + /* normal state */ + sr_process_events(sysrepo->subscription, sysrepo->session,NULL); +} + +sysrepo_uv_ctx_t *sysrepo_ctx_init() +{ + int ret = SR_ERR_OK; + sr_conn_ctx_t *sr_connection = NULL; + sr_session_ctx_t *sr_session = NULL; + sr_subscription_ctx_t *sr_subscription = NULL; + + if (!ret) ret = sr_connect(0, &sr_connection); + if (!ret) ret = sr_session_start(sr_connection, SR_DS_RUNNING, &sr_session); + if (ret){ + kr_log_error( + "[sysrepo] failed to start sysrepo session: %s\n", + sr_strerror(ret)); + return NULL; + } + + sysrepo_uv_ctx_t *sysrepo = malloc(sizeof(sysrepo_uv_ctx_t)); + sysrepo->connection = sr_connection; + sysrepo->session = sr_session; + sysrepo->callback = sysrepo_subscr_cb; + sysrepo->subscription = sr_subscription; + + return sysrepo; +} + +int sysrepo_ctx_start(uv_loop_t *loop, sysrepo_uv_ctx_t *sysrepo) +{ + int ret = SR_ERR_OK; + + int pipe; + ret = sr_get_event_pipe(sysrepo->subscription, &pipe); + if (ret != SR_ERR_OK) { + kr_log_error("[sysrepo] failed to get sysrepo event pipe: %s\n", sr_strerror(ret)); + free(sysrepo); + return ret; + } + ret = uv_poll_init(loop, &sysrepo->uv_handle, pipe); + if (ret) { + kr_log_error("[libuv] failed to initialize uv_poll: %s\n", uv_strerror(ret)); + free(sysrepo); + return ret; + } + sysrepo->uv_handle.data = sysrepo; + ret = uv_poll_start(&sysrepo->uv_handle, UV_READABLE, sysrepo_subscr_cb_tramp); + if (ret) { + kr_log_error("[libuv] failed to start uv_poll: %s\n", uv_strerror(ret)); + sysrepo_subscription_free(sysrepo); + } + return ret; +} + +int sysrepo_ctx_deinit(sysrepo_uv_ctx_t *sysrepo) +{ + sysrepo_subscription_free(sysrepo); + + return 0; +} \ No newline at end of file diff --git a/modules/sysrepo/common/sysrepo.h b/modules/sysrepo/common/sysrepo.h new file mode 100644 index 000000000..d2e1eefaa --- /dev/null +++ b/modules/sysrepo/common/sysrepo.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#define YM_COMMON "cznic-resolver-common" +#define YM_KRES "cznic-resolver-knot" +#define XPATH_BASE "/" YM_COMMON ":dns-resolver" +#define XPATH_RPC_BASE "/"YM_COMMON +#define XPATH_GC XPATH_BASE "/cache/" YM_KRES ":garbage-collector" +#define XPATH_LOG XPATH_BASE "logging/" YM_KRES "log" + + +typedef struct sysrepo_uv_ctx sysrepo_uv_ctx_t; +/** Callback for sysrepo subscriptions */ +typedef void (*sysrepo_cb)(sysrepo_uv_ctx_t *sysrepo, int status); + +/** Context for sysrepo subscriptions. + * might add some other fields in future */ +struct sysrepo_uv_ctx { + sr_conn_ctx_t *connection; + sr_session_ctx_t *session; + sr_subscription_ctx_t *subscription; + sysrepo_cb callback; + uv_poll_t uv_handle; +}; + +/** Init sysrepo context */ +sysrepo_uv_ctx_t *sysrepo_ctx_init(); + +/** Start subscribtion with sysrepo context */ +int sysrepo_ctx_start(uv_loop_t *loop, sysrepo_uv_ctx_t *sysrepo); + +/** Destroy sysrepo context */ +int sysrepo_ctx_deinit(sysrepo_uv_ctx_t *sysrepo); + +/** Logging funcion */ +int send_log(); + +/** Convert node's value to string */ +int valtostr(const sr_val_t *value, char **strval); + + + diff --git a/modules/sysrepo/meson.build b/modules/sysrepo/meson.build index a2b5726b4..4ba72d54e 100644 --- a/modules/sysrepo/meson.build +++ b/modules/sysrepo/meson.build @@ -13,22 +13,24 @@ sysrepo_common_src = files([ 'common/sysrepo_conf.h', 'common/string_helper.h', 'common/string_helper.c', + 'common/sysrepo.h', + 'common/sysrepo.c', ]) c_src_lint += sysrepo_common_src -if build_sysrepo - sysrepo_mod = shared_module( - 'sysrepo', - sysrepo_src, - sysrepo_common_src, - dependencies: [ - luajit_inc, - libyang, - libsysrepo, - ], - include_directories: mod_inc_dir, - name_prefix: '', - install: true, - install_dir: modules_dir, - ) -endif \ No newline at end of file +# if build_sysrepo +# sysrepo_mod = shared_module( +# 'sysrepo', +# sysrepo_src, +# sysrepo_common_src, +# dependencies: [ +# luajit_inc, +# libyang, +# libsysrepo, +# ], +# include_directories: mod_inc_dir, +# name_prefix: '', +# install: true, +# install_dir: modules_dir, +# ) +# endif \ No newline at end of file diff --git a/utils/client/meson.build b/utils/client/meson.build index 5c08c6c6e..80a5e4d5d 100644 --- a/utils/client/meson.build +++ b/utils/client/meson.build @@ -22,15 +22,15 @@ if get_option('client') != 'disabled' endif -if build_client - kresc = executable( - 'kresc', - kresc_src, - dependencies: [ - contrib_dep, - libedit, - ], - install: true, - install_dir: get_option('sbindir'), - ) -endif +# if build_client +# kresc = executable( +# 'kresc', +# kresc_src, +# dependencies: [ +# contrib_dep, +# libedit, +# ], +# install: true, +# install_dir: get_option('sbindir'), +# ) +# endif diff --git a/utils/kres_watcher/main.c b/utils/kres_watcher/main.c deleted file mode 100644 index d1f9a9643..000000000 --- a/utils/kres_watcher/main.c +++ /dev/null @@ -1,4 +0,0 @@ -int main(int argc, char *argv[]) -{ - return 0; -} \ No newline at end of file diff --git a/utils/kres_watcher/meson.build b/utils/kres_watcher/meson.build deleted file mode 100644 index a8fa67d4d..000000000 --- a/utils/kres_watcher/meson.build +++ /dev/null @@ -1,22 +0,0 @@ -kres_watcher_src = files([ - 'main.c', -]) -c_src_lint += kres_watcher_src - -if build_sysrepo - kres_watcher = executable( - 'kres-watcher', - kres_watcher_src, - sysrepo_common_src, - dependencies: [ - contrib_dep, - libkres_dep, - libyang, - libsysrepo, - libknot, - luajit_inc, - ], - install: true, - install_dir: get_option('sbindir'), - ) -endif \ No newline at end of file diff --git a/utils/kresctl/commands.c b/utils/kresctl/commands.c new file mode 100644 index 000000000..9e994060d --- /dev/null +++ b/utils/kresctl/commands.c @@ -0,0 +1,572 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "lib/generic/array.h" +#include "modules/sysrepo/common/string_helper.h" +#include "modules/sysrepo/common/sysrepo.h" +#include "commands.h" +#include "process.h" +#include "conf_file.h" + +/* Build-in commands */ +#define CMD_EXIT "exit" +#define CMD_HELP "help" +#define CMD_VERSION "version" +#define CMD_IMPORT "import" +#define CMD_EXPORT "export" +#define CMD_BEGIN "begin" +#define CMD_COMMIT "commit" +#define CMD_ABORT "abort" +#define CMD_VALIDATE "validate" +#define CMD_DIFF "diff" +#define CMD_PERSIST "persist" + + +static int cmd_import(cmd_args_t *args) +{ + struct lyd_node *data; + sr_session_ctx_t *sr_session = NULL; + const char *file_path = args->argv[0]; + int flags = LYD_OPT_CONFIG | LYD_OPT_TRUSTED | LYD_OPT_STRICT; + + int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sr_session); + if (ret) { + printf("failed to start sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + if (!ret) ret = step_load_data(sr_session, file_path, flags, &data); + + + /* replace config (always spends data) */ + ret = sr_replace_config(sr_session, YM_COMMON, data, 0, 0); + if (ret) { + printf("failed to replace configuration, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + ret = sr_session_stop(sr_session); + if (ret) { + printf("failed to stop sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + return CLI_EOK; +} + +static int cmd_export(cmd_args_t *args) +{ + char *xpath; + struct lyd_node *data; + sr_session_ctx_t *sr_session = NULL; + FILE *file = NULL; + + /* If argument, open file for writting */ + if (args->argc == 1) { + file = fopen(args->argv[0], "w"); + if (!file) { + printf("Failed to open \"%s\" for writing (%s)", args->argv[0], strerror(errno)); + return CLI_ECMD; + } + } + + asprintf(&xpath, "/%s:*", YM_COMMON); + int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sr_session); + if (!ret) ret = sr_get_data(sr_session, xpath, 0, 0, 0, &data); + if (ret) { + printf("failed to get configuration from sysrepo, %s\n", sr_strerror(ret)); + free(xpath); + return CLI_ECMD; + } + + ret = sr_session_stop(sr_session); + if (ret) { + printf("failed to stop sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + /* print exported data to file or stdout */ + lyd_print_file(file ? file : stdout, data, LYD_JSON, LYP_FORMAT | LYP_WITHSIBLINGS); + lyd_free_withsiblings(data); + free(xpath); + + return CLI_EOK; +} + +static int cmd_begin(cmd_args_t *args) +{ + sr_session_ctx_t *sr_session = NULL; + + if (sysrepo_ctx->session) { + printf("transaction has already begin\n"); + return CLI_ECMD; + } + + int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_CANDIDATE, &sr_session); + if (ret) { + printf("failed to start sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + ret = sr_lock(sr_session, YM_COMMON); + if (ret) { + printf("failed to lock candidate datastore, %s\n", sr_strerror(ret)); + sr_session_stop(sr_session); + return CLI_ECMD; + } + sysrepo_ctx->session = sr_session; + + return CLI_EOK; +} + +static int cmd_commmit(cmd_args_t *args) +{ + if (!sysrepo_ctx->session){ + printf("no active transaction\n"); + return CLI_ECMD; + } + + int ret = sr_validate(sysrepo_ctx->session, YM_COMMON, 0); + if (ret) { + printf("validation failed, %s\n", sr_strerror(ret)); + } + + /* switch datastore to RUNNING */ + ret = sr_session_switch_ds(sysrepo_ctx->session, SR_DS_RUNNING); + /* copy configuration from CANDIDATE to RUNNING datastore */ + if (!ret) ret = sr_copy_config(sysrepo_ctx->session, YM_COMMON, + SR_DS_CANDIDATE, args->timeout, 0); + if (ret) { + printf("commit failed, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + ret = sr_session_stop(sysrepo_ctx->session); + if (ret) { + printf("failed to stop sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + sysrepo_ctx->session = NULL; + + return CLI_EOK; +} + +static int cmd_abort(cmd_args_t *args) +{ + if (!sysrepo_ctx->session){ + printf("no active transaction\n"); + return 1; + } + + int ret = sr_session_stop(sysrepo_ctx->session); + if (ret) { + printf("failed to stop sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + sysrepo_ctx->session = NULL; + + return CLI_EOK; +} + +static int cmd_validate(cmd_args_t *args) +{ + if (!sysrepo_ctx->session){ + printf("no active transaction\n"); + return CLI_ECMD; + } + + int ret = sr_validate(sysrepo_ctx->session, YM_COMMON, args->timeout); + if (ret) { + printf("validation failed, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + return CLI_EOK; +} + +static int cmd_diff(cmd_args_t *args) +{ + return CLI_EOK; +} + +static int cmd_persist(cmd_args_t *args) +{ + sr_session_ctx_t *sr_session = NULL; + + int ret = sr_session_start(sysrepo_ctx->connection, SR_DS_STARTUP, &sr_session); + if (ret) { + printf("failed to start sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + /* copy configuration from RUNNING to STARTUP datastore */ + if (!ret) ret = sr_copy_config(sr_session, YM_COMMON, SR_DS_RUNNING, 0, 0); + if (ret) { + printf("commit failed, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + ret = sr_session_stop(sr_session); + if (ret) { + printf("failed to stop sysrepo session, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + + return CLI_EOK; +} + +/* Funtcions for YANG commands */ + +static int cmd_leaf(cmd_args_t *args) +{ + int ret = CLI_EOK; + const char *path = args->desc->xpath; + + /* get operational data */ + if (!args->argc) { + + char *strval; + sr_val_t *val = NULL; + sr_session_ctx_t *sr_session = NULL; + + if (!ret) ret = sr_session_start(sysrepo_ctx->connection, SR_DS_OPERATIONAL, &sr_session); + if (!ret) ret = sr_get_item(sr_session, path, args->timeout, &val); + if (ret) { + printf("get configuration data failed, %s\n", sr_strerror(ret)); + sr_session_stop(sr_session); + return CLI_ECMD; + } + + valtostr(val, &strval); + printf("%s = %s\n", args->desc->name, strval); + + sr_session_stop(sr_session); + sr_free_val(val); + free(strval); + } + /* set configuration data */ + else if (args->argc == 1) { + /* check if there is active session */ + if (!sysrepo_ctx->session) { + printf("no active transaction\n"); + return CLI_ECMD; + } + + if (!ret) ret = sr_session_switch_ds(sysrepo_ctx->session,SR_DS_CANDIDATE); + if (!ret) ret = sr_set_item_str(sysrepo_ctx->session, path, args->argv[0], NULL, 0); + if (!ret) ret = sr_apply_changes(sysrepo_ctx->session, 0, 0); + if (ret) { + printf("set data value failed, %s\n", sr_strerror(ret)); + return CLI_ECMD; + } + } + else { + printf("too many arguments\n"); + return CLI_ECMD; + } + + return CLI_EOK; +} + +static int cmd_container(cmd_args_t *args) +{ + int ret = CLI_EOK; + const char *path = args->desc->xpath; + + /* get operational data */ + if (!args->argc) { + + char *xpath; + struct lyd_node *data = NULL; + sr_session_ctx_t *sr_session = NULL; + + asprintf(&xpath, "%s/*//.", args->desc->xpath); + if (!ret) ret = sr_session_start(sysrepo_ctx->connection, SR_DS_OPERATIONAL, &sr_session); + if (!ret) ret = sr_get_data(sr_session, xpath, 0, args->timeout, 0, &data); + if (ret) { + printf("get configuration data failed, %s\n", sr_strerror(ret)); + sr_session_stop(sr_session); + return CLI_ECMD; + } + + lyd_print_file(stdout, data, LYD_JSON, LYP_FORMAT | LYP_WITHSIBLINGS); + + lyd_free_withsiblings(data); + sr_session_stop(sr_session); + free(xpath); + } + else { + printf("too many arguments\n"); + return CLI_ECMD; + } + + return CLI_EOK; +} + +static int cmd_list(cmd_args_t *args) +{ + return CLI_EOK; +} + +static int cmd_leaflist(cmd_args_t *args) +{ + return CLI_EOK; +} + +static int cmd_rpc(cmd_args_t *args) +{ + int ret = CLI_EOK; + sr_session_ctx_t *sr_session = NULL; + sr_val_t *output = NULL; + size_t output_count = 0; + + //TODO: prepare input + + ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sr_session); + if (!ret) ret = sr_rpc_send(sr_session, args->desc->xpath, 0, 0, args->timeout, &output, &output_count); + if (ret) { + printf("[] failed to send RPC operation, %s\n", sr_strerror(ret)); + sr_session_stop(sr_session); + return CLI_ECMD; + } + + sr_free_values(output, output_count); + sr_session_stop(sr_session); + return CLI_EOK; +} + +static int cmd_notif(cmd_args_t *args) +{ + return CLI_EOK; +} + + +static void cmd_dynarray_deep_free(cmd_dynarray_t * d) +{ + dynarray_foreach(cmd, cmd_desc_t *, i, *d) { + cmd_desc_t *cmd = *i; + free(cmd->xpath); + free(cmd->name); + free(cmd); + } + cmd_dynarray_free(d); +} + +cmd_dynarray_t dyn_cmd_table; + +const cmd_desc_t cmd_table[] = { + /* name, function, xpath, flags */ + { CMD_EXIT, NULL, "", CMD_FNONE }, + { CMD_HELP, print_commands, "", CMD_FNONE }, + { CMD_VERSION, print_version, "", CMD_FNONE }, + /* Configuration file */ + { CMD_IMPORT, cmd_import, "", CMD_FNONE }, + { CMD_EXPORT, cmd_export, "", CMD_FNONE }, + /* Transaction */ + { CMD_BEGIN, cmd_begin, "", CMD_FINTER }, + { CMD_COMMIT, cmd_commmit, "", CMD_FINTER }, + { CMD_ABORT, cmd_abort, "", CMD_FINTER }, + { CMD_VALIDATE, cmd_validate, "", CMD_FINTER }, + { CMD_DIFF, cmd_diff, "", CMD_FINTER }, + { CMD_PERSIST, cmd_persist, "", CMD_FNONE }, + /* */ + { NULL } +}; + +static void cmd_help_dynarray_deep_free(cmd_help_dynarray_t * d) +{ + dynarray_foreach(cmd_help, cmd_help_t *, i, *d) { + cmd_help_t *cmd_help = *i; + free(cmd_help); + } + cmd_help_dynarray_free(d); +} + +cmd_help_dynarray_t dyn_cmd_help_table; + +const cmd_help_t cmd_help_table[] = { + /* name, arguments, description */ + { CMD_EXIT, "", "Exit the program." }, + { CMD_HELP, "", "Print the program help." }, + { CMD_VERSION, "", "Print the program version." }, + { "", "", "" }, + { CMD_IMPORT, "", "Import YAML configuration file." }, + { CMD_EXPORT, "", "Export YAML configuration file." }, + { "", "", "" }, + { CMD_BEGIN, "", "Begin a transaction." }, + { CMD_COMMIT, "", "Commit a transaction." }, + { CMD_ABORT, "", "Abort a transaction." }, + { CMD_VALIDATE, "", "Validate a transaction changes." }, + { CMD_DIFF, "", "Show configuration changes." }, + { CMD_PERSIST, "", "Make running configuration persist during system reboots." }, + { NULL } +}; + +static const char *create_cmd_name(const char* xpath) +{ + char* name = (char*)malloc(strlen(xpath)+1); + if (!name){ + // memory allocation failed. + return ""; + } + strcpy(name,xpath); + + /* remove modules from name, the order is important */ + remove_substr(name, XPATH_BASE"/"); + remove_substr(name, XPATH_BASE); + remove_substr(name, "/"YM_COMMON":"); + remove_substr(name, YM_COMMON":"); + remove_substr(name, YM_KRES":"); + /* replace '/' with '.' */ + replace_char(name, '/', '.'); + + return name; +} + +static int create_cmd(struct lys_node *node) +{ + if (node->nodetype == LYS_GROUPING || + node->nodetype == LYS_USES) { + return 0; + } + + const char *xpath = lys_data_path(node); + const char *name = create_cmd_name(xpath); + + if (!strlen(name)) { + free(xpath); + free(name); + return 0; + } + + cmd_desc_t *cmd = malloc(sizeof(cmd_desc_t)); + cmd->name = name; + cmd->xpath = xpath; + cmd->fcn = &cmd_leaf; + cmd->flags = CMD_FNONE; + + switch (node->nodetype) { + case LYS_CONTAINER: + cmd->fcn = &cmd_container; + break; + case LYS_LEAF: + cmd->fcn = &cmd_leaf; + break; + case LYS_LEAFLIST: + cmd->fcn = &cmd_leaflist; + break; + case LYS_LIST: + cmd->fcn = &cmd_list; + break; + case LYS_ACTION: + cmd->fcn = &cmd_rpc; + break; + case LYS_RPC: + cmd->fcn = &cmd_rpc; + break; + case LYS_NOTIF: + cmd->fcn = &cmd_notif; + break; + default: + cmd->fcn = &cmd_leaf; + break; + } + + cmd_help_t *cmd_help = malloc(sizeof(cmd_help_t)); + cmd_help->name = name; + cmd_help->desc = node->dsc; + cmd_help->params = ""; + + cmd_help_dynarray_add(&dyn_cmd_help_table, &cmd_help); + cmd_dynarray_add(&dyn_cmd_table, &cmd); + + return CLI_EOK; +} + +static void schema_iterator(struct lys_node *root) +{ + assert(root != NULL); + + struct lys_node *node = NULL; + + LY_TREE_FOR(root, node) { + assert(node != NULL); + + create_cmd(node); + + /* do childs only for CONTAINERS, ignore others */ + if (node->child + && (node->nodetype != LYS_LIST) + && (node->nodetype != LYS_RPC) + && (node->nodetype != LYS_ACTION) + && (node->nodetype != LYS_NOTIF) + ) { + schema_iterator(node->child); + } + } +} + +int create_cmd_table(sr_conn_ctx_t *sr_connection) +{ + assert(sr_connection != NULL); + + int ret = CLI_EOK; + struct lys_node *root = NULL; + struct ly_ctx *ly_context = NULL; + struct lys_module *module = NULL; + + ly_context = sr_get_context(sr_connection); + if (!ly_context) { + printf("[] failed to get libyang context\n"); + return CLI_ERR; + } + + /* get libyang context */ + root = ly_ctx_get_node(ly_context, NULL, XPATH_BASE, 0); + assert(root != NULL); + + /* iterate thrue all schema nodes */ + schema_iterator(root); + + return CLI_EOK; +} + +void destroy_cmd_table() +{ + cmd_dynarray_deep_free(&dyn_cmd_table); + cmd_help_dynarray_deep_free(&dyn_cmd_help_table); +} + +int print_version(cmd_args_t *args) +{ + printf("%s (%s), version %s\n", PROGRAM_NAME, PROJECT_NAME, PACKAGE_VERSION); + + return 0; +} + +int print_commands(cmd_args_t *args) +{ + printf("\nCommands:\n"); + + /* Print all build-in commands */ + for (const cmd_help_t *cmd = cmd_help_table; cmd->name != NULL; cmd++) { + printf(" %-15s %-15s %s\n", cmd->name, cmd->params, cmd->desc); + } + printf("\n"); + + /* Print all created commands */ + dynarray_foreach(cmd_help, cmd_help_t *, i, dyn_cmd_help_table) { + cmd_help_t *cmd_help = *i; + printf(" %-40s %-10s %s\n", cmd_help->name, cmd_help->params, cmd_help->desc); + } + + printf("\n" + "Note:\n" + ""); + + return 0; +} diff --git a/utils/kresctl/commands.h b/utils/kresctl/commands.h new file mode 100644 index 000000000..65cc4e6fe --- /dev/null +++ b/utils/kresctl/commands.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include + +#include "contrib/dynarray.h" +#include "process.h" + + +struct cmd_desc; +typedef struct cmd_desc cmd_desc_t; +typedef struct cmd_help cmd_help_t; + +typedef enum { + CMD_FNONE = 0, + CMD_FINTER = 1 << 0, /* Interactive-only command. */ + CMD_FSTATE = 2 << 0, /* State-only command */ +} cmd_flag_t; + +typedef struct { + const cmd_desc_t *desc; + const sr_datastore_t ds; + int argc; + const char **argv; + int timeout; +} cmd_args_t; + +struct cmd_desc { + const char *name; + int (*fcn)(cmd_args_t *); + const char *xpath; + cmd_flag_t flags; +}; + +struct cmd_help { + const char *name; + const char *params; + const char *desc; +}; + +dynarray_declare(cmd, cmd_desc_t *, DYNARRAY_VISIBILITY_STATIC, 0) + dynarray_define(cmd, cmd_desc_t *, DYNARRAY_VISIBILITY_STATIC) + +dynarray_declare(cmd_help, cmd_help_t *, DYNARRAY_VISIBILITY_STATIC, 0) + dynarray_define(cmd_help, cmd_help_t *, DYNARRAY_VISIBILITY_STATIC) + +int create_cmd_table(sr_conn_ctx_t *sr_connection); + +void destroy_cmd_table(); + +int print_version(cmd_args_t *args); + +int print_commands(cmd_args_t *args); + +extern cmd_dynarray_t dyn_cmd_table; +extern cmd_help_dynarray_t dyn_cmd_help_table; + +extern const cmd_desc_t cmd_table[]; +extern const cmd_help_t cmd_help_table[]; \ No newline at end of file diff --git a/utils/kresctl/conf_file.c b/utils/kresctl/conf_file.c new file mode 100644 index 000000000..bebdd616f --- /dev/null +++ b/utils/kresctl/conf_file.c @@ -0,0 +1,63 @@ +#include +#include +#include +#include + +#include "conf_file.h" + + +static int step_read_file(FILE *file, char **mem) +{ + size_t mem_size, mem_used; + + mem_size = 512; + mem_used = 0; + *mem = malloc(mem_size); + + do { + if (mem_used == mem_size) { + mem_size >>= 1; + *mem = realloc(*mem, mem_size); + } + + mem_used += fread(*mem + mem_used, 1, mem_size - mem_used, file); + } while (mem_used == mem_size); + + if (ferror(file)) { + free(*mem); + printf("Error reading from file (%s)\n", strerror(errno)); + return EXIT_FAILURE; + } else if (!feof(file)) { + free(*mem); + printf("Unknown file problem\n"); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +int step_load_data(sr_session_ctx_t *sr_session, const char *file_path, int flags, struct lyd_node **data) +{ + struct ly_ctx *ly_ctx; + char *ptr; + + ly_ctx = (struct ly_ctx *)sr_get_context(sr_session_get_connection(sr_session)); + + /* parse import data */ + if (file_path) { + *data = lyd_parse_path(ly_ctx, file_path, LYD_JSON, flags, NULL); + } else { + /* need to load the data into memory first */ + if (step_read_file(stdin, &ptr)) { + return EXIT_FAILURE; + } + *data = lyd_parse_mem(ly_ctx, ptr, LYD_JSON, flags); + free(ptr); + } + if (ly_errno) { + printf("Data parsing failed\n"); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/utils/kresctl/conf_file.h b/utils/kresctl/conf_file.h new file mode 100644 index 000000000..cfcdffb09 --- /dev/null +++ b/utils/kresctl/conf_file.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include + + +int step_load_data(sr_session_ctx_t *sess, const char *file_path, int flags, struct lyd_node **data); \ No newline at end of file diff --git a/utils/kresctl/deps/ctype.h b/utils/kresctl/deps/ctype.h new file mode 100644 index 000000000..d1e9d274c --- /dev/null +++ b/utils/kresctl/deps/ctype.h @@ -0,0 +1,193 @@ +/* Copyright (C) 2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +/*! + * \brief Locale-independent ctype functions. + */ + +#pragma once + +#include +#include +#include + +enum { + CT_DIGIT = 1 << 0, + CT_UPPER = 1 << 1, + CT_LOWER = 1 << 2, + CT_XDIGT = 1 << 3, + CT_PUNCT = 1 << 4, + CT_PRINT = 1 << 5, + CT_SPACE = 1 << 6, +}; + +static const uint8_t char_mask[256] = { + // 0 - 8 + ['\t'] = CT_SPACE, + ['\n'] = CT_SPACE, + ['\v'] = CT_SPACE, + ['\f'] = CT_SPACE, + ['\r'] = CT_SPACE, + // 14 - 31 + [' '] = CT_PRINT | CT_SPACE, + + ['!'] = CT_PRINT | CT_PUNCT, + ['"'] = CT_PRINT | CT_PUNCT, + ['#'] = CT_PRINT | CT_PUNCT, + ['$'] = CT_PRINT | CT_PUNCT, + ['%'] = CT_PRINT | CT_PUNCT, + ['&'] = CT_PRINT | CT_PUNCT, + ['\''] = CT_PRINT | CT_PUNCT, + ['('] = CT_PRINT | CT_PUNCT, + [')'] = CT_PRINT | CT_PUNCT, + ['*'] = CT_PRINT | CT_PUNCT, + ['+'] = CT_PRINT | CT_PUNCT, + [','] = CT_PRINT | CT_PUNCT, + ['-'] = CT_PRINT | CT_PUNCT, + ['.'] = CT_PRINT | CT_PUNCT, + ['/'] = CT_PRINT | CT_PUNCT, + + ['0'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['1'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['2'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['3'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['4'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['5'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['6'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['7'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['8'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + ['9'] = CT_PRINT | CT_DIGIT | CT_XDIGT, + + [':'] = CT_PRINT | CT_PUNCT, + [';'] = CT_PRINT | CT_PUNCT, + ['<'] = CT_PRINT | CT_PUNCT, + ['='] = CT_PRINT | CT_PUNCT, + ['>'] = CT_PRINT | CT_PUNCT, + ['?'] = CT_PRINT | CT_PUNCT, + ['@'] = CT_PRINT | CT_PUNCT, + + ['A'] = CT_PRINT | CT_UPPER | CT_XDIGT, + ['B'] = CT_PRINT | CT_UPPER | CT_XDIGT, + ['C'] = CT_PRINT | CT_UPPER | CT_XDIGT, + ['D'] = CT_PRINT | CT_UPPER | CT_XDIGT, + ['E'] = CT_PRINT | CT_UPPER | CT_XDIGT, + ['F'] = CT_PRINT | CT_UPPER | CT_XDIGT, + ['G'] = CT_PRINT | CT_UPPER, + ['H'] = CT_PRINT | CT_UPPER, + ['I'] = CT_PRINT | CT_UPPER, + ['J'] = CT_PRINT | CT_UPPER, + ['K'] = CT_PRINT | CT_UPPER, + ['L'] = CT_PRINT | CT_UPPER, + ['M'] = CT_PRINT | CT_UPPER, + ['N'] = CT_PRINT | CT_UPPER, + ['O'] = CT_PRINT | CT_UPPER, + ['P'] = CT_PRINT | CT_UPPER, + ['Q'] = CT_PRINT | CT_UPPER, + ['R'] = CT_PRINT | CT_UPPER, + ['S'] = CT_PRINT | CT_UPPER, + ['T'] = CT_PRINT | CT_UPPER, + ['U'] = CT_PRINT | CT_UPPER, + ['V'] = CT_PRINT | CT_UPPER, + ['W'] = CT_PRINT | CT_UPPER, + ['X'] = CT_PRINT | CT_UPPER, + ['Y'] = CT_PRINT | CT_UPPER, + ['Z'] = CT_PRINT | CT_UPPER, + + ['['] = CT_PRINT | CT_PUNCT, + ['\\'] = CT_PRINT | CT_PUNCT, + [']'] = CT_PRINT | CT_PUNCT, + ['^'] = CT_PRINT | CT_PUNCT, + ['_'] = CT_PRINT | CT_PUNCT, + ['`'] = CT_PRINT | CT_PUNCT, + + ['a'] = CT_PRINT | CT_LOWER | CT_XDIGT, + ['b'] = CT_PRINT | CT_LOWER | CT_XDIGT, + ['c'] = CT_PRINT | CT_LOWER | CT_XDIGT, + ['d'] = CT_PRINT | CT_LOWER | CT_XDIGT, + ['e'] = CT_PRINT | CT_LOWER | CT_XDIGT, + ['f'] = CT_PRINT | CT_LOWER | CT_XDIGT, + ['g'] = CT_PRINT | CT_LOWER, + ['h'] = CT_PRINT | CT_LOWER, + ['i'] = CT_PRINT | CT_LOWER, + ['j'] = CT_PRINT | CT_LOWER, + ['k'] = CT_PRINT | CT_LOWER, + ['l'] = CT_PRINT | CT_LOWER, + ['m'] = CT_PRINT | CT_LOWER, + ['n'] = CT_PRINT | CT_LOWER, + ['o'] = CT_PRINT | CT_LOWER, + ['p'] = CT_PRINT | CT_LOWER, + ['q'] = CT_PRINT | CT_LOWER, + ['r'] = CT_PRINT | CT_LOWER, + ['s'] = CT_PRINT | CT_LOWER, + ['t'] = CT_PRINT | CT_LOWER, + ['u'] = CT_PRINT | CT_LOWER, + ['v'] = CT_PRINT | CT_LOWER, + ['w'] = CT_PRINT | CT_LOWER, + ['x'] = CT_PRINT | CT_LOWER, + ['y'] = CT_PRINT | CT_LOWER, + ['z'] = CT_PRINT | CT_LOWER, + + ['{'] = CT_PRINT | CT_PUNCT, + ['|'] = CT_PRINT | CT_PUNCT, + ['}'] = CT_PRINT | CT_PUNCT, + ['~'] = CT_PRINT | CT_PUNCT, + // 127 - 255 +}; + +static inline bool is_alnum(uint8_t c) +{ + return char_mask[c] & (CT_DIGIT | CT_UPPER | CT_LOWER); +} + +static inline bool is_alpha(uint8_t c) +{ + return char_mask[c] & (CT_UPPER | CT_LOWER); +} + +static inline bool is_digit(uint8_t c) +{ + return char_mask[c] & CT_DIGIT; +} + +static inline bool is_xdigit(uint8_t c) +{ + return char_mask[c] & CT_XDIGT; +} + +static inline bool is_lower(uint8_t c) +{ + return char_mask[c] & CT_LOWER; +} + +static inline bool is_upper(uint8_t c) +{ + return char_mask[c] & CT_UPPER; +} + +static inline bool is_print(uint8_t c) +{ + return char_mask[c] & CT_PRINT; +} + +static inline bool is_punct(uint8_t c) +{ + return char_mask[c] & CT_PUNCT; +} + +static inline bool is_space(uint8_t c) +{ + return char_mask[c] & CT_SPACE; +} diff --git a/utils/kresctl/deps/lookup.c b/utils/kresctl/deps/lookup.c new file mode 100644 index 000000000..5a20d9e48 --- /dev/null +++ b/utils/kresctl/deps/lookup.c @@ -0,0 +1,279 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include + +#include "lookup.h" +#include "mempattern.h" +#include "contrib/ucw/mempool.h" +#include "libknot/error.h" + +int lookup_init(lookup_t *lookup) +{ + if (lookup == NULL) { + return KNOT_EINVAL; + } + memset(lookup, 0, sizeof(*lookup)); + + mm_ctx_mempool(&lookup->mm, MM_DEFAULT_BLKSIZE); + lookup->trie = trie_create(&lookup->mm); + if (lookup->trie == NULL) { + mp_delete(lookup->mm.ctx); + return KNOT_ENOMEM; + } + + return KNOT_EOK; +} + +static void reset_output(lookup_t *lookup) +{ + if (lookup == NULL) { + return; + } + + mm_free(&lookup->mm, lookup->found.key); + lookup->found.key = NULL; + lookup->found.data = NULL; + + lookup->iter.count = 0; + + mm_free(&lookup->mm, lookup->iter.first_key); + lookup->iter.first_key = NULL; + + trie_it_free(lookup->iter.it); + lookup->iter.it = NULL; +} + +void lookup_deinit(lookup_t *lookup) +{ + if (lookup == NULL) { + return; + } + + reset_output(lookup); + + trie_free(lookup->trie); + mp_delete(lookup->mm.ctx); +} + +int lookup_insert(lookup_t *lookup, const char *str, void *data) +{ + if (lookup == NULL || str == NULL) { + return KNOT_EINVAL; + } + + size_t str_len = strlen(str); + if (str_len == 0) { + return KNOT_EINVAL; + } + + trie_val_t *val = trie_get_ins(lookup->trie, (const trie_key_t *)str, str_len); + if (val == NULL) { + return KNOT_ENOMEM; + } + *val = data; + + return KNOT_EOK; +} + +static int set_key(lookup_t *lookup, char **dst, const char *key, size_t key_len) +{ + if (*dst != NULL) { + mm_free(&lookup->mm, *dst); + } + *dst = mm_alloc(&lookup->mm, key_len + 1); + if (*dst == NULL) { + return KNOT_ENOMEM; + } + memcpy(*dst, key, key_len); + (*dst)[key_len] = '\0'; + + return KNOT_EOK; +} + +int lookup_search(lookup_t *lookup, const char *str, size_t str_len) +{ + if (lookup == NULL) { + return KNOT_EINVAL; + } + + // Change NULL string to the empty one. + if (str == NULL) { + str = ""; + } + + reset_output(lookup); + + size_t new_len = 0; + trie_it_t *it = trie_it_begin(lookup->trie); + for (; !trie_it_finished(it); trie_it_next(it)) { + size_t len; + const char *key = (const char *)trie_it_key(it, &len); + + // Compare with a shorter key. + if (len < str_len) { + int ret = memcmp(str, key, len); + if (ret >= 0) { + continue; + } else { + break; + } + } + + // Compare with an equal length or longer key. + int ret = memcmp(str, key, str_len); + if (ret == 0) { + lookup->iter.count++; + + // First candidate. + if (lookup->iter.count == 1) { + ret = set_key(lookup, &lookup->found.key, key, len); + if (ret != KNOT_EOK) { + break; + } + lookup->found.data = *trie_it_val(it); + new_len = len; + // Another candidate. + } else if (new_len > str_len) { + if (new_len > len) { + new_len = len; + } + while (memcmp(lookup->found.key, key, new_len) != 0) { + new_len--; + } + } + // Stop if greater than the key, and also than all the following keys. + } else if (ret < 0) { + break; + } + } + trie_it_free(it); + + switch (lookup->iter.count) { + case 0: + return KNOT_ENOENT; + case 1: + return KNOT_EOK; + default: + // Store full name of the first candidate. + if (set_key(lookup, &lookup->iter.first_key, lookup->found.key, + strlen(lookup->found.key)) != KNOT_EOK) { + return KNOT_ENOMEM; + } + lookup->found.key[new_len] = '\0'; + lookup->found.data = NULL; + + return KNOT_EFEWDATA; + } +} + +void lookup_list(lookup_t *lookup) +{ + if (lookup == NULL || lookup->iter.first_key == NULL) { + return; + } + + if (lookup->iter.it != NULL) { + if (trie_it_finished(lookup->iter.it)) { + trie_it_free(lookup->iter.it); + lookup->iter.it = NULL; + return; + } + + trie_it_next(lookup->iter.it); + + size_t len; + const char *key = (const char *)trie_it_key(lookup->iter.it, &len); + + int ret = set_key(lookup, &lookup->found.key, key, len); + if (ret == KNOT_EOK) { + lookup->found.data = *trie_it_val(lookup->iter.it); + } + return; + } + + lookup->iter.it = trie_it_begin(lookup->trie); + while (!trie_it_finished(lookup->iter.it)) { + size_t len; + const char *key = (const char *)trie_it_key(lookup->iter.it, &len); + + if (strncmp(key, lookup->iter.first_key, len) == 0) { + int ret = set_key(lookup, &lookup->found.key, key, len); + if (ret == KNOT_EOK) { + lookup->found.data = *trie_it_val(lookup->iter.it); + } + break; + } + trie_it_next(lookup->iter.it); + } +} + +static void print_options(lookup_t *lookup, EditLine *el) +{ + // Get terminal lines. + unsigned lines = 0; + if (el_get(el, EL_GETTC, "li", &lines) != 0 || lines < 3) { + return; + } + + for (size_t i = 1; i <= lookup->iter.count; i++) { + lookup_list(lookup); + printf("\n%s", lookup->found.key); + + if (i > 1 && i % (lines - 1) == 0 && i < lookup->iter.count) { + printf("\n Display next from %zu possibilities? (y or n)", + lookup->iter.count); + char next; + el_getc(el, &next); + if (next != 'y') { + break; + } + } + } + + printf("\n"); + fflush(stdout); +} + +void lookup_complete(lookup_t *lookup, const char *str, size_t str_len, + EditLine *el, bool add_space) +{ + if (lookup == NULL || el == NULL) { + return; + } + + // Try to complete the command name. + int ret = lookup_search(lookup, str, str_len); + switch (ret) { + case KNOT_EOK: + el_deletestr(el, str_len); + el_insertstr(el, lookup->found.key); + if (add_space) { + el_insertstr(el, " "); + } + break; + case KNOT_EFEWDATA: + if (strlen(lookup->found.key) > str_len) { + el_deletestr(el, str_len); + el_insertstr(el, lookup->found.key); + } else { + print_options(lookup, el); + } + break; + default: + break; + } +} \ No newline at end of file diff --git a/utils/kresctl/deps/lookup.h b/utils/kresctl/deps/lookup.h new file mode 100644 index 000000000..fafe9dd05 --- /dev/null +++ b/utils/kresctl/deps/lookup.h @@ -0,0 +1,112 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include + +#include "libknot/mm_ctx.h" +#include "qp-trie/trie.h" + +/*! Lookup context. */ +typedef struct { + /*! Memory pool context. */ + knot_mm_t mm; + /*! Main trie storage. */ + trie_t *trie; + + /*! Current (iteration) data context. */ + struct { + /*! Stored key. */ + char *key; + /*! Corresponding key data. */ + void *data; + } found; + + /*! Iteration context. */ + struct { + /*! Total number of possibilies. */ + size_t count; + /*! The first possibility. */ + char *first_key; + /*! Hat-trie iterator. */ + trie_it_t *it; + } iter; +} lookup_t; + +/*! + * Initializes the lookup context. + * + * \param[in] lookup Lookup context. + * + * \return Error code, KNOT_EOK if successful. + */ +int lookup_init(lookup_t *lookup); + +/*! + * Deinitializes the lookup context. + * + * \param[in] lookup Lookup context. + */ +void lookup_deinit(lookup_t *lookup); + +/*! + * Inserts given key and data into the lookup. + * + * \param[in] lookup Lookup context. + * \param[in] str Textual key. + * \param[in] data Key textual data. + * + * \return Error code, KNOT_EOK if successful. + */ +int lookup_insert(lookup_t *lookup, const char *str, void *data); + +/*! + * Searches the lookup container for the given key. + * + * \note If one candidate, lookup.found contains the key/data, + * if more candidates, lookup.found contains the common key prefix and + * lookup.iter.first_key is the first candidate key. + * + * \param[in] lookup Lookup context. + * \param[in] str Textual key. + * \param[in] str_len Textual key length. + * + * \return Error code, KNOT_EOK if 1 candidate, KNOT_ENOENT if no candidate, + * and KNOT_EFEWDATA if more candidates are possible. + */ +int lookup_search(lookup_t *lookup, const char *str, size_t str_len); + +/*! + * Moves the lookup iterator to the next key candidate. + * + * \note lookup.found is updated. + * + * \param[in] lookup Lookup context. + */ +void lookup_list(lookup_t *lookup); + +/*! + * Completes the string based on the lookup content or prints all candidates. + * + * \param[in] lookup Lookup context. + * \param[in] str Textual key. + * \param[in] str_len Textual key length. + * \param[in] el Editline context. + * \param[in] add_space Add one space after completed string flag. + */ +void lookup_complete(lookup_t *lookup, const char *str, size_t str_len, + EditLine *el, bool add_space); \ No newline at end of file diff --git a/utils/kresctl/deps/mempattern.c b/utils/kresctl/deps/mempattern.c new file mode 100644 index 000000000..c6cc4a21f --- /dev/null +++ b/utils/kresctl/deps/mempattern.c @@ -0,0 +1,122 @@ +/* Copyright (C) 2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include + +#include "mempattern.h" +#include "string.h" +#include "ucw/mempool.h" + +static void mm_nofree(void *p) +{ + /* nop */ +} + +static void *mm_malloc(void *ctx, size_t n) +{ + (void)ctx; + return malloc(n); +} + +void *mm_alloc(knot_mm_t *mm, size_t size) +{ + if (mm) { + return mm->alloc(mm->ctx, size); + } else { + return malloc(size); + } +} + +void *mm_calloc(knot_mm_t *mm, size_t nmemb, size_t size) +{ + if (nmemb == 0 || size == 0) { + return NULL; + } + if (mm) { + size_t total_size = nmemb * size; + if (total_size / nmemb != size) { // Overflow check + return NULL; + } + void *mem = mm_alloc(mm, total_size); + if (mem == NULL) { + return NULL; + } + return memzero(mem, total_size); + } else { + return calloc(nmemb, size); + } +} +/* +void *mm_realloc(knot_mm_t *mm, void *what, size_t size, size_t prev_size) +{ + if (mm) { + void *p = mm->alloc(mm->ctx, size); + if (p == NULL) { + return NULL; + } else { + if (what) { + memcpy(p, what, + prev_size < size ? prev_size : size); + } + mm_free(mm, what); + return p; + } + } else { + return realloc(what, size); + } +} +*/ +char *mm_strdup(knot_mm_t *mm, const char *s) +{ + if (s == NULL) { + return NULL; + } + if (mm) { + size_t len = strlen(s) + 1; + void *mem = mm_alloc(mm, len); + if (mem == NULL) { + return NULL; + } + return memcpy(mem, s, len); + } else { + return strdup(s); + } +} + +void mm_free(knot_mm_t *mm, void *what) +{ + if (mm) { + if (mm->free) { + mm->free(what); + } + } else { + free(what); + } +} + +void mm_ctx_init(knot_mm_t *mm) +{ + mm->ctx = NULL; + mm->alloc = mm_malloc; + mm->free = free; +} + +void mm_ctx_mempool(knot_mm_t *mm, size_t chunk_size) +{ + mm->ctx = mp_new(chunk_size); + mm->alloc = (knot_mm_alloc_t)mp_alloc; + mm->free = mm_nofree; +} diff --git a/utils/kresctl/deps/mempattern.h b/utils/kresctl/deps/mempattern.h new file mode 100644 index 000000000..95e1128f2 --- /dev/null +++ b/utils/kresctl/deps/mempattern.h @@ -0,0 +1,47 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +/*! + * \brief Memory allocation related functions. + */ + +#pragma once + +#include "libknot/mm_ctx.h" + +/*! \brief Default memory block size. */ +#define MM_DEFAULT_BLKSIZE 4096 + +/*! \brief Allocs using 'mm' if any, uses system malloc() otherwise. */ +void *mm_alloc(knot_mm_t *mm, size_t size); + +/*! \brief Callocs using 'mm' if any, uses system calloc() otherwise. */ +void *mm_calloc(knot_mm_t *mm, size_t nmemb, size_t size); + +/*! \brief Reallocs using 'mm' if any, uses system realloc() otherwise. */ +//void *mm_realloc(knot_mm_t *mm, void *what, size_t size, size_t prev_size); + +/*! \brief Strdups using 'mm' if any, uses system strdup() otherwise. */ +char *mm_strdup(knot_mm_t *mm, const char *s); + +/*! \brief Free using 'mm' if any, uses system free() otherwise. */ +void mm_free(knot_mm_t *mm, void *what); + +/*! \brief Initialize default memory allocation context. */ +void mm_ctx_init(knot_mm_t *mm); + +/*! \brief Memory pool context. */ +void mm_ctx_mempool(knot_mm_t *mm, size_t chunk_size); diff --git a/utils/kresctl/deps/qp-trie/trie.c b/utils/kresctl/deps/qp-trie/trie.c new file mode 100644 index 000000000..c5ee89769 --- /dev/null +++ b/utils/kresctl/deps/qp-trie/trie.c @@ -0,0 +1,1451 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + Copyright (C) 2018 Tony Finch + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + + The code originated from https://github.com/fanf2/qp/blob/master/qp.c + at revision 5f6d93753. + */ + +#include +#include +#include +#include + +#include "contrib/qp-trie/trie.h" +#include "contrib/macros.h" +#include "contrib/mempattern.h" +#include "libknot/errcode.h" + +typedef unsigned int uint; +typedef uint64_t index_t; /*!< nibble index into a key */ +typedef uint64_t word; /*!< A type-punned word */ +typedef uint bitmap_t; /*!< Bit-maps, using the range of 1<<0 to 1<<16 (inclusive). */ + +typedef char static_assert_pointer_fits_in_word + [sizeof(word) >= sizeof(uintptr_t) ? 1 : -1]; + +#define KEYLENBITS 31 + +/*! \brief trie keys have lengths + * + * 32 bits are enough for key lengths; probably even 16 bits would be. + * However, a 32 bit length means the alignment will be a multiple of + * 4, allowing us to stash the COW and BRANCH flags in the bottom bits + * of a pointer to a key. + * + * We need to steal a couple of bits from the length to keep the COW + * state of key allocations. + */ +typedef struct { + uint32_t cow:1, len:KEYLENBITS; + trie_key_t chars[]; +} tkey_t; + +/*! \brief A trie node is a pair of words. + * + * Each word is type-punned, depending on whether this is a branch + * node or a leaf node. We'll define some accessor functions to wrap + * this up into something reasonably safe. + * + * We aren't using a union to avoid problems with strict aliasing, and + * we aren't using bitfields because we want to control exactly which + * bits in the word are used by each field (in particular the flags). + * + * Branch nodes are never allocated individually: they are always part + * of either the root node or the twigs array of their parent branch. + * + * In a branch: + * + * `i` contains flags, bitmap, and index, explained in more detail below. + * + * `p` is a pointer to the "twigs", an array of child nodes. + * + * In a leaf: + * + * `i` is cast from a pointer to a tkey_t, with flags in the bottom bits. + * + * `p` is a trie_val_t. + */ +typedef struct node { + word i; + void *p; +} node_t; + +struct trie { + node_t root; // undefined when weight == 0, see empty_root() + size_t weight; + knot_mm_t mm; +}; + +/*! \brief size (in bits) of nibble (half-byte) indexes into keys + * + * The bottom bit is clear for the upper nibble, and set for the lower + * nibble, big-endian style, since the tree has to be in lexicographic + * order. The index increases from one branch node to the next as you + * go deeper into the trie. All the keys below a branch are identical + * up to the nibble identified by the branch. + * + * (see also tkey_t.len above) + */ +#define TWIDTH_INDEX 33 + +/*! \brief exclusive limit on indexes */ +#define TMAX_INDEX (BIG1 << TWIDTH_INDEX) + +/*! \brief size (in bits) of branch bitmap + * + * The bitmap indicates which subtries are present. The present child + * nodes are stored in the twigs array (with no holes between them). + * + * To simplify storing keys that are prefixes of each other, the + * end-of-string position is treated as an extra nibble value, ordered + * before all others. So there are 16 possible real nibble values, + * plus one value for nibbles past the end of the key. + */ +#define TWIDTH_BMP 17 + +/* + * We're constructing the layout of the branch `i` field in a careful + * way to avoid mistakes, getting the compiler to calculate values + * rather than typing them in by hand. + */ +enum { + TSHIFT_BRANCH = 0, + TSHIFT_COW, + TSHIFT_BMP, + TOP_BMP = TSHIFT_BMP + TWIDTH_BMP, + TSHIFT_INDEX = TOP_BMP, + TOP_INDEX = TSHIFT_INDEX + TWIDTH_INDEX, +}; + +typedef char static_assert_fields_fit_in_word + [TOP_INDEX <= sizeof(word) * CHAR_BIT ? 1 : -1]; + +typedef char static_assert_bmp_fits + [TOP_BMP <= sizeof(bitmap_t) * CHAR_BIT ? 1 : -1]; + +#define BIG1 ((word)1) +#define TMASK(width, shift) (((BIG1 << (width)) - BIG1) << (shift)) + +/*! \brief is this node a branch or a leaf? */ +#define TFLAG_BRANCH (BIG1 << TSHIFT_BRANCH) + +/*! \brief copy-on-write flag, used in both leaves and branches */ +#define TFLAG_COW (BIG1 << TSHIFT_COW) + +/*! \brief for extracting pointer to key */ +#define TMASK_LEAF (~(word)(TFLAG_BRANCH | TFLAG_COW)) + +/*! \brief mask for extracting nibble index */ +#define TMASK_INDEX TMASK(TWIDTH_INDEX, TSHIFT_INDEX) + +/*! \brief mask for extracting bitmap */ +#define TMASK_BMP TMASK(TWIDTH_BMP, TSHIFT_BMP) + +/*! \brief bitmap entry for NOBYTE */ +#define BMP_NOBYTE (BIG1 << TSHIFT_BMP) + +/*! \brief Initialize a new leaf, copying the key, and returning failure code. */ +static int mkleaf(node_t *leaf, const trie_key_t *key, uint32_t len, knot_mm_t *mm) +{ + if (unlikely((word)len > (BIG1 << KEYLENBITS))) + return KNOT_ENOMEM; + tkey_t *lkey = mm_alloc(mm, sizeof(tkey_t) + len); + if (unlikely(!lkey)) + return KNOT_ENOMEM; + lkey->cow = 0; + lkey->len = len; + memcpy(lkey->chars, key, len); + word i = (uintptr_t)lkey; + assert((i & TFLAG_BRANCH) == 0); + *leaf = (node_t){ .i = i, .p = NULL }; + return KNOT_EOK; +} + +/*! \brief construct a branch node */ +static node_t mkbranch(index_t index, bitmap_t bmp, node_t *twigs) +{ + word i = TFLAG_BRANCH | bmp + | (index << TSHIFT_INDEX); + assert(index < TMAX_INDEX); + assert((bmp & ~TMASK_BMP) == 0); + return (node_t){ .i = i, .p = twigs }; +} + +/*! \brief Make an empty root node. */ +static node_t empty_root(void) +{ + return mkbranch(TMAX_INDEX-1, 0, NULL); +} + +/*! \brief Propagate error codes. */ +#define ERR_RETURN(x) \ + do { \ + int err_code_ = x; \ + if (unlikely(err_code_ != KNOT_EOK)) \ + return err_code_; \ + } while (false) + + +/*! \brief Test flags to determine type of this node. */ +static bool isbranch(const node_t *t) +{ + return t->i & TFLAG_BRANCH; +} + +static tkey_t *tkey(const node_t *t) +{ + assert(!isbranch(t)); + return (tkey_t *)(uintptr_t)(t->i & TMASK_LEAF); +} + +static trie_val_t *tvalp(node_t *t) +{ + assert(!isbranch(t)); + return &t->p; +} + +/*! \brief Given a branch node, return the index of the corresponding nibble in the key. */ +static index_t branch_index(const node_t *t) +{ + assert(isbranch(t)); + return (t->i & TMASK_INDEX) >> TSHIFT_INDEX; +} + +static bitmap_t branch_bmp(const node_t *t) +{ + assert(isbranch(t)); + return (t->i & TMASK_BMP); +} + +/*! + * \brief Count the number of set bits. + * + * \TODO This implementation may be relatively slow on some HW. + */ +static uint branch_weight(const node_t *t) +{ + assert(isbranch(t)); + uint n = __builtin_popcount(t->i & TMASK_BMP); + assert(n > 1 && n <= TWIDTH_BMP); + return n; +} + +/*! \brief Compute offset of an existing child in a branch node. */ +static uint twigoff(const node_t *t, bitmap_t bit) +{ + assert(isbranch(t)); + assert(__builtin_popcount(bit) == 1); + return __builtin_popcount(t->i & TMASK_BMP & (bit - 1)); +} + +/*! \brief Extract a nibble from a key and turn it into a bitmask. */ +static bitmap_t keybit(index_t ni, const trie_key_t *key, uint32_t len) +{ + index_t bytei = ni >> 1; + + if (bytei >= len) + return BMP_NOBYTE; + + uint8_t ki = (uint8_t)key[bytei]; + uint nibble = (ni & 1) ? (ki & 0xf) : (ki >> 4); + + // skip one for NOBYTE nibbles after the end of the key + return BIG1 << (nibble + 1 + TSHIFT_BMP); +} + +/*! \brief Extract a nibble from a key and turn it into a bitmask. */ +static bitmap_t twigbit(const node_t *t, const trie_key_t *key, uint32_t len) +{ + assert(isbranch(t)); + return keybit(branch_index(t), key, len); +} + +/*! \brief Test if a branch node has a child indicated by a bitmask. */ +static bool hastwig(const node_t *t, bitmap_t bit) +{ + assert(isbranch(t)); + assert((bit & ~TMASK_BMP) == 0); + assert(__builtin_popcount(bit) == 1); + return t->i & bit; +} + +/*! \brief Get pointer to packed array of child nodes. */ +static node_t* twigs(node_t *t) +{ + assert(isbranch(t)); + return t->p; +} + +/*! \brief Get pointer to a particular child of a branch node. */ +static node_t* twig(node_t *t, uint i) +{ + assert(i < branch_weight(t)); + return twigs(t) + i; +} + +/*! \brief Get twig number of a child node TODO: better description. */ +static uint twig_number(node_t *child, node_t *parent) +{ + // twig array index using pointer arithmetic + ptrdiff_t num = child - twigs(parent); + assert(num >= 0 && num < branch_weight(parent)); + return (uint)num; +} + +/*! \brief Simple string comparator. */ +static int key_cmp(const trie_key_t *k1, uint32_t k1_len, + const trie_key_t *k2, uint32_t k2_len) +{ + int ret = memcmp(k1, k2, MIN(k1_len, k2_len)); + if (ret != 0) { + return ret; + } + + /* Key string is equal, compare lengths. */ + if (k1_len == k2_len) { + return 0; + } else if (k1_len < k2_len) { + return -1; + } else { + return 1; + } +} + +trie_t* trie_create(knot_mm_t *mm) +{ + trie_t *trie = mm_alloc(mm, sizeof(trie_t)); + if (trie != NULL) { + trie->root = empty_root(); + trie->weight = 0; + if (mm != NULL) + trie->mm = *mm; + else + mm_ctx_init(&trie->mm); + } + return trie; +} + +/*! \brief Free anything under the trie node, except for the passed pointer itself. */ +static void clear_trie(node_t *trie, knot_mm_t *mm) +{ + if (!isbranch(trie)) { + mm_free(mm, tkey(trie)); + } else { + uint n = branch_weight(trie); + for (uint i = 0; i < n; ++i) + clear_trie(twig(trie, i), mm); + mm_free(mm, twigs(trie)); + } +} + +void trie_free(trie_t *tbl) +{ + if (tbl == NULL) + return; + if (tbl->weight) + clear_trie(&tbl->root, &tbl->mm); + mm_free(&tbl->mm, tbl); +} + +void trie_clear(trie_t *tbl) +{ + assert(tbl); + if (!tbl->weight) + return; + clear_trie(&tbl->root, &tbl->mm); + tbl->root = empty_root(); + tbl->weight = 0; +} + +static bool dup_trie(node_t *copy, const node_t *orig, trie_dup_cb dup_cb, knot_mm_t *mm) +{ + if (isbranch(orig)) { + uint n = branch_weight(orig); + node_t *cotw = mm_alloc(mm, n * sizeof(*cotw)); + if (cotw == NULL) { + return NULL; + } + const node_t *ortw = twigs((node_t *)orig); + for (uint i = 0; i < n; ++i) { + if (!dup_trie(cotw + i, ortw + i, dup_cb, mm)) { + while (i-- > 0) { + clear_trie(cotw + i, mm); + } + mm_free(mm, cotw); + return false; + } + } + *copy = mkbranch(branch_index(orig), branch_bmp(orig), cotw); + } else { + tkey_t *key = tkey(orig); + if (mkleaf(copy, key->chars, key->len, mm) != KNOT_EOK) { + return false; + } + if ((copy->p = dup_cb(orig->p, mm)) == NULL) { + mm_free(mm, tkey(copy)); + return false; + } + } + return true; +} + +trie_t* trie_dup(const trie_t *orig, trie_dup_cb dup_cb, knot_mm_t *mm) +{ + if (orig == NULL) { + return NULL; + } + trie_t *copy = mm_alloc(mm, sizeof(*copy)); + if (copy == NULL) { + return NULL; + } + copy->weight = orig->weight; + if (mm != NULL) { + copy->mm = *mm; + } else { + mm_ctx_init(©->mm); + } + if (copy->weight) { + if (!dup_trie(©->root, &orig->root, dup_cb, mm)) { + mm_free(mm, copy); + return NULL; + } + } + return copy; +} + +size_t trie_weight(const trie_t *tbl) +{ + assert(tbl); + return tbl->weight; +} + +trie_val_t* trie_get_try(trie_t *tbl, const trie_key_t *key, uint32_t len) +{ + assert(tbl); + if (!tbl->weight) + return NULL; + node_t *t = &tbl->root; + while (isbranch(t)) { + __builtin_prefetch(twigs(t)); + bitmap_t b = twigbit(t, key, len); + if (!hastwig(t, b)) + return NULL; + t = twig(t, twigoff(t, b)); + } + tkey_t *lkey = tkey(t); + if (key_cmp(key, len, lkey->chars, lkey->len) != 0) + return NULL; + return tvalp(t); +} + +/* Optimization: the approach isn't ideal, as e.g. walking through the prefix + * is duplicated and we explicitly construct the wildcard key. Still, it's close + * to optimum which would be significantly more complicated and error-prone to write. */ +trie_val_t* trie_get_try_wildcard(trie_t *tbl, const trie_key_t *key, uint32_t len) +{ + assert(tbl); + if (!tbl->weight) + return NULL; + // Find leaf sharing the longest common prefix; see ns_find_branch() for explanation. + node_t *t = &tbl->root; + while (isbranch(t)) { + __builtin_prefetch(twigs(t)); + bitmap_t b = twigbit(t, key, len); + uint i = hastwig(t, b) ? twigoff(t, b) : 0; + t = twig(t, i); + } + const tkey_t * const lcp_key = tkey(t); + + // Find the last matching zero byte or -1 (source of synthesis) + int i_lmz = -1; + for (int i = 0; i < len && i < lcp_key->len && key[i] == lcp_key->chars[i]; ++i) { + if (key[i] == '\0' && i < len - 1) // do not count the terminating zero + i_lmz = i; + // Shortcut: we may have found an exact match. + if (i == len - 1 && len == lcp_key->len) + return tvalp(t); + } + if (len == 0) // The empty name needs separate handling. + return lcp_key->len == 0 ? tvalp(t) : NULL; + + // Construct the key of the wildcard we need and look it up. + const int wild_len = i_lmz + 3; + uint8_t wild_key[wild_len]; + memcpy(wild_key, key, wild_len - 2); + wild_key[wild_len - 2] = '*'; + wild_key[wild_len - 1] = '\0'; // LF is always 0-terminated ATM + return trie_get_try(tbl, wild_key, wild_len); +} + +/*! \brief Delete leaf t with parent p; b is the bit for t under p. + * Optionally return the deleted value via val. The function can't fail. */ +static void del_found(trie_t *tbl, node_t *t, node_t *p, bitmap_t b, trie_val_t *val) +{ + assert(!tkey(t)->cow); + mm_free(&tbl->mm, tkey(t)); + if (val != NULL) + *val = *tvalp(t); // we return trie_val_t directly when deleting + --tbl->weight; + if (unlikely(!p)) { // whole trie was a single leaf + assert(tbl->weight == 0); + tbl->root = empty_root(); + return; + } + // remove leaf t as child of p + node_t *tp = twigs(p); + uint ci = twig_number(t, p); + uint cc = branch_weight(p); // child count + + if (cc == 2) { + // collapse binary node p: move the other child to the parent + *p = tp[1 - ci]; + mm_free(&tbl->mm, tp); + return; + } + memmove(tp + ci, tp + ci + 1, sizeof(node_t) * (cc - ci - 1)); + p->i &= ~b; + node_t *newt = mm_realloc(&tbl->mm, tp, sizeof(node_t) * (cc - 1), + sizeof(node_t) * cc); + if (likely(newt != NULL)) + p->p = newt; + // We can ignore mm_realloc failure because an oversized twig + // array is OK - only beware that next time the prev_size + // passed to mm_realloc will not be correct; TODO? +} + +int trie_del(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t *val) +{ + assert(tbl); + if (!tbl->weight) + return KNOT_ENOENT; + node_t *t = &tbl->root; // current and parent node + node_t *p = NULL; + bitmap_t b = 0; + while (isbranch(t)) { + __builtin_prefetch(twigs(t)); + b = twigbit(t, key, len); + if (!hastwig(t, b)) + return KNOT_ENOENT; + p = t; + t = twig(t, twigoff(t, b)); + } + tkey_t *lkey = tkey(t); + if (key_cmp(key, len, lkey->chars, lkey->len) != 0) + return KNOT_ENOENT; + del_found(tbl, t, p, b, val); + return KNOT_EOK; +} + +/*! + * \brief Stack of nodes, storing a path down a trie. + * + * The structure also serves directly as the public trie_it_t type, + * in which case it always points to the current leaf, unless we've finished + * (i.e. it->len == 0). + * stack[0] is always a valid pointer to the root -> ns_gettrie() + */ +typedef struct trie_it { + node_t* *stack; /*!< The stack; malloc is used directly instead of mm. */ + uint32_t len; /*!< Current length of the stack. */ + uint32_t alen; /*!< Allocated/available length of the stack. */ + /*! \brief Initial storage for \a stack; it should fit in most use cases. */ + node_t* stack_init[250]; +} nstack_t; + +/*! \brief Create a node stack containing just the root (or empty). */ +static void ns_init(nstack_t *ns, trie_t *tbl) +{ + assert(tbl); + ns->stack = ns->stack_init; + ns->alen = sizeof(ns->stack_init) / sizeof(ns->stack_init[0]); + ns->stack[0] = &tbl->root; + ns->len = (tbl->weight > 0); +} + +static inline trie_t * ns_gettrie(nstack_t *ns) +{ + assert(ns && ns->stack && ns->stack[0]); + return (struct trie *)ns->stack[0]; +} + +/*! \brief Free inside of the stack, i.e. not the passed pointer itself. */ +static void ns_cleanup(nstack_t *ns) +{ + assert(ns && ns->stack); + if (likely(ns->stack == ns->stack_init)) + return; + free(ns->stack); + #ifndef NDEBUG + ns->stack = NULL; + ns->alen = 0; + #endif +} + +/*! \brief Allocate more space for the stack. */ +static int ns_longer_alloc(nstack_t *ns) +{ + ns->alen *= 2; + size_t new_size = ns->alen * sizeof(node_t *); + node_t **st; + if (ns->stack == ns->stack_init) { + st = malloc(new_size); + if (st != NULL) + memcpy(st, ns->stack, ns->len * sizeof(node_t *)); + } else { + st = realloc(ns->stack, new_size); + } + if (st == NULL) + return KNOT_ENOMEM; + ns->stack = st; + return KNOT_EOK; +} + +/*! \brief Ensure the node stack can be extended by one. */ +static inline int ns_longer(nstack_t *ns) +{ + // get a longer stack if needed + if (likely(ns->len < ns->alen)) + return KNOT_EOK; + return ns_longer_alloc(ns); // hand-split the part suitable for inlining +} + +/*! + * \brief Find the "branching point" as if searching for a key. + * + * The whole path to the point is kept on the passed stack; + * always at least the root will remain on the top of it. + * Beware: the precise semantics of this function is rather tricky. + * The top of the stack will contain: the corresponding leaf if exact + * match is found; or the immediate node below a + * branching-point-on-edge or the branching-point itself. + * + * \param idiff Set the index of first differing nibble, or TMAX_INDEX for an exact match + * \param tbit Set the bit of the closest leaf's nibble at index idiff + * \param kbit Set the bit of the key's nibble at index idiff + * + * \return KNOT_EOK or KNOT_ENOMEM. + */ +static int ns_find_branch(nstack_t *ns, const trie_key_t *key, uint32_t len, + index_t *idiff, bitmap_t *tbit, bitmap_t *kbit) +{ + assert(ns && ns->len && idiff); + // First find some leaf with longest matching prefix. + while (isbranch(ns->stack[ns->len - 1])) { + ERR_RETURN(ns_longer(ns)); + node_t *t = ns->stack[ns->len - 1]; + __builtin_prefetch(twigs(t)); + bitmap_t b = twigbit(t, key, len); + // Even if our key is missing from this branch we need to + // keep iterating down to a leaf. It doesn't matter which + // twig we choose since the keys are all the same up to this + // index. Note that blindly using twigoff(t, b) can cause + // an out-of-bounds index if it equals twigmax(t). + uint i = hastwig(t, b) ? twigoff(t, b) : 0; + ns->stack[ns->len++] = twig(t, i); + } + tkey_t *lkey = tkey(ns->stack[ns->len-1]); + // Find index of the first char that differs. + size_t bytei = 0; + uint32_t klen = lkey->len; + for (bytei = 0; bytei < MIN(len,klen); bytei++) { + if (key[bytei] != lkey->chars[bytei]) + break; + } + // Find which half-byte has matched. + index_t index = bytei << 1; + if (bytei == len && len == lkey->len) { // found equivalent key + index = TMAX_INDEX; + goto success; + } + if (likely(bytei < MIN(len,klen))) { + uint8_t k2 = (uint8_t)lkey->chars[bytei]; + uint8_t k1 = (uint8_t)key[bytei]; + if (((k1 ^ k2) & 0xf0) == 0) + index += 1; + } + // now go up the trie from the current leaf + node_t *t; + do { + if (unlikely(ns->len == 1)) + goto success; // only the root stays on the stack + t = ns->stack[ns->len - 2]; + if (branch_index(t) < index) + goto success; + --ns->len; + } while (true); +success: + #ifndef NDEBUG // invariants on successful return + assert(ns->len); + if (isbranch(ns->stack[ns->len - 1])) { + t = ns->stack[ns->len - 1]; + assert(branch_index(t) >= index); + } + if (ns->len > 1) { + t = ns->stack[ns->len - 2]; + assert(branch_index(t) < index || index == TMAX_INDEX); + } + #endif + *idiff = index; + *tbit = keybit(index, lkey->chars, lkey->len); + *kbit = keybit(index, key, len); + return KNOT_EOK; +} + +/*! + * \brief Advance the node stack to the last leaf in the subtree. + * + * \return KNOT_EOK or KNOT_ENOMEM. + */ +static int ns_last_leaf(nstack_t *ns) +{ + assert(ns); + do { + ERR_RETURN(ns_longer(ns)); + node_t *t = ns->stack[ns->len - 1]; + if (!isbranch(t)) + return KNOT_EOK; + uint lasti = branch_weight(t) - 1; + ns->stack[ns->len++] = twig(t, lasti); + } while (true); +} + +/*! + * \brief Advance the node stack to the first leaf in the subtree. + * + * \return KNOT_EOK or KNOT_ENOMEM. + */ +static int ns_first_leaf(nstack_t *ns) +{ + assert(ns && ns->len); + do { + ERR_RETURN(ns_longer(ns)); + node_t *t = ns->stack[ns->len - 1]; + if (!isbranch(t)) + return KNOT_EOK; + ns->stack[ns->len++] = twig(t, 0); + } while (true); +} + +/*! + * \brief Advance the node stack to the leaf that is previous to the current node. + * + * \note Prefix leaf under the current node DOES count (if present; perhaps questionable). + * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM. + */ +static int ns_prev_leaf(nstack_t *ns) +{ + assert(ns && ns->len > 0); + + node_t *t = ns->stack[ns->len - 1]; + // Beware: BMP_NOBYTE child is ordered *before* its parent. + if (isbranch(t) && hastwig(t, BMP_NOBYTE)) { + ERR_RETURN(ns_longer(ns)); + ns->stack[ns->len++] = twig(t, 0); + return KNOT_EOK; + } + + for (; ns->len >= 2; --ns->len) { + t = ns->stack[ns->len - 1]; + node_t *p = ns->stack[ns->len - 2]; + uint ci = twig_number(t, p); + if (ci == 0) // we've got to go up again + continue; + // t isn't the first child -> go down the previous one + ns->stack[ns->len - 1] = twig(p, ci - 1); + return ns_last_leaf(ns); + } + return KNOT_ENOENT; // root without empty key has no previous leaf +} + +/*! + * \brief Advance the node stack to the leaf that is successor to the current node. + * + * \param skip_prefixed skip any nodes whose key is a prefix of the current one. + * If false, prefix leaf or anything else under the current node DOES count. + * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM. + */ +static int ns_next_leaf(nstack_t *ns, const bool skip_pefixed) +{ + assert(ns && ns->len > 0); + + node_t *t = ns->stack[ns->len - 1]; + if (!skip_pefixed && isbranch(t)) + return ns_first_leaf(ns); + for (; ns->len >= 2; --ns->len) { + t = ns->stack[ns->len - 1]; + node_t *p = ns->stack[ns->len - 2]; + uint ci = twig_number(t, p); + if (skip_pefixed && ci == 0 && hastwig(t, BMP_NOBYTE)) { + // Keys in the subtree of p are suffixes of the key of t, + // so we've got to go one level higher + // (this can't happen more than once) + continue; + } + uint cc = branch_weight(p); + assert(ci + 1 <= cc); + if (ci + 1 == cc) { + // t is the last child of p, so we need to keep climbing + continue; + } + // go down the next child of p + ns->stack[ns->len - 1] = twig(p, ci + 1); + return ns_first_leaf(ns); + } + return KNOT_ENOENT; // not found, as no more parent is available +} + +/*! \brief Advance the node stack to leaf with longest prefix of the current key. */ +static int ns_prefix(nstack_t *ns) +{ + assert(ns && ns->len > 0); + const node_t *start = ns->stack[ns->len - 1]; + // Walk up the trie until we find a BMP_NOBYTE child. + while (--ns->len > 0) { + node_t *p = ns->stack[ns->len - 1]; + if (!hastwig(p, BMP_NOBYTE)) + continue; + node_t *end = twig(p, 0); + // In case we started in a BMP_NOBYTE leaf, the first step up + // did NOT shorten the key and we would get back into the same + // node again. + if (end == start) + continue; + ns->stack[ns->len++] = end; + return KNOT_EOK; + } + return KNOT_ENOENT; // not found, as no more parent is available +} + +/*! \brief less-or-equal search. + * + * \return KNOT_EOK for exact match, 1 for previous, KNOT_ENOENT for not-found, + * or KNOT_E*. + */ +static int ns_get_leq(nstack_t *ns, const trie_key_t *key, uint32_t len) +{ + // First find the key with longest-matching prefix + index_t idiff; + bitmap_t tbit, kbit; + ERR_RETURN(ns_find_branch(ns, key, len, &idiff, &tbit, &kbit)); + node_t *t = ns->stack[ns->len - 1]; + if (idiff == TMAX_INDEX) // found exact match + return KNOT_EOK; + // Get t: the last node on matching path + bitmap_t b; + if (isbranch(t) && branch_index(t) == idiff) { + // t is OK + b = kbit; + } else { + // the top of the stack was the first unmatched node -> step up + if (ns->len == 1) { + // root was unmatched already + if (kbit < tbit) + return KNOT_ENOENT; + ERR_RETURN(ns_last_leaf(ns)); + return 1; + } + --ns->len; + t = ns->stack[ns->len - 1]; + b = twigbit(t, key, len); + } + // Now we re-do the first "non-matching" step in the trie + // but try the previous child if key was less (it may not exist) + int i = hastwig(t, b) + ? (int)twigoff(t, b) - (kbit < tbit) + : (int)twigoff(t, b) - 1 /* twigoff returns successor when !hastwig */; + if (i >= 0) { + ERR_RETURN(ns_longer(ns)); + ns->stack[ns->len++] = twig(t, i); + ERR_RETURN(ns_last_leaf(ns)); + } else { + ERR_RETURN(ns_prev_leaf(ns)); + } + return 1; +} + +int trie_get_leq(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t **val) +{ + assert(tbl && val); + if (tbl->weight == 0) { + if (val) *val = NULL; + return KNOT_ENOENT; + } + // We try to do without malloc. + nstack_t ns_local; + ns_init(&ns_local, tbl); + nstack_t *ns = &ns_local; + + int ret = ns_get_leq(ns, key, len); + if (ret == KNOT_EOK || ret == 1) { + assert(!isbranch(ns->stack[ns->len - 1])); + if (val) *val = tvalp(ns->stack[ns->len - 1]); + } else { + if (val) *val = NULL; + } + ns_cleanup(ns); + return ret; +} + +int trie_it_get_leq(trie_it_t *it, const trie_key_t *key, uint32_t len) +{ + assert(it && it->stack[0] && it->alen); + const trie_t *tbl = ns_gettrie(it); + if (tbl->weight == 0) { + it->len = 0; + return KNOT_ENOENT; + } + it->len = 1; + int ret = ns_get_leq(it, key, len); + if (ret == KNOT_EOK || ret == 1) { + assert(trie_it_key(it, NULL)); + } else { + it->len = 0; + } + return ret; +} + +/* see below */ +static int cow_pushdown(trie_cow_t *cow, nstack_t *ns); + +/*! \brief implementation of trie_get_ins() and trie_get_cow() */ +static trie_val_t* cow_get_ins(trie_cow_t *cow, trie_t *tbl, + const trie_key_t *key, uint32_t len) +{ + assert(tbl); + // First leaf in an empty tbl? + if (unlikely(!tbl->weight)) { + if (unlikely(mkleaf(&tbl->root, key, len, &tbl->mm))) + return NULL; + ++tbl->weight; + return tvalp(&tbl->root); + } + { // Intentionally un-indented; until end of function, to bound cleanup attr. + // Find the branching-point + __attribute__((cleanup(ns_cleanup))) + nstack_t ns_local; + ns_init(&ns_local, tbl); + nstack_t *ns = &ns_local; + index_t idiff; + bitmap_t tbit, kbit; + if (unlikely(ns_find_branch(ns, key, len, &idiff, &tbit, &kbit))) + return NULL; + if (unlikely(cow && cow_pushdown(cow, ns) != KNOT_EOK)) + return NULL; + node_t *t = ns->stack[ns->len - 1]; + if (idiff == TMAX_INDEX) // the same key was already present + return tvalp(t); + node_t leaf, *leafp; + if (unlikely(mkleaf(&leaf, key, len, &tbl->mm))) + return NULL; + + if (isbranch(t) && branch_index(t) == idiff) { + // The node t needs a new leaf child. + assert(!hastwig(t, kbit)); + // new child position and original child count + uint s = twigoff(t, kbit); + uint m = branch_weight(t); + node_t *nt = mm_realloc(&tbl->mm, twigs(t), + sizeof(node_t) * (m + 1), sizeof(node_t) * m); + if (unlikely(!nt)) + goto err_leaf; + memmove(nt + s + 1, nt + s, sizeof(node_t) * (m - s)); + leafp = nt + s; + *t = mkbranch(idiff, branch_bmp(t) | kbit, nt); + } else { + // We need to insert a new binary branch with leaf at *t. + // Note: it works the same for the case where we insert above root t. + #ifndef NDEBUG + if (ns->len > 1) { + node_t *pt = ns->stack[ns->len - 2]; + assert(hastwig(pt, twigbit(pt, key, len))); + } + #endif + node_t *nt = mm_alloc(&tbl->mm, sizeof(node_t) * 2); + if (unlikely(!nt)) + goto err_leaf; + node_t t2 = *t; // Save before overwriting t. + *t = mkbranch(idiff, tbit | kbit, nt); + *twig(t, twigoff(t, tbit)) = t2; + leafp = twig(t, twigoff(t, kbit)); + }; + *leafp = leaf; + ++tbl->weight; + return tvalp(leafp); +err_leaf: + mm_free(&tbl->mm, tkey(&leaf)); + return NULL; + } +} + +trie_val_t* trie_get_ins(trie_t *tbl, const trie_key_t *key, uint32_t len) +{ + return cow_get_ins(NULL, tbl, key, len); +} + +/*! \brief Apply a function to every trie_val_t*, in order; a recursive solution. */ +static int apply_nodes(node_t *t, int (*f)(trie_val_t *, void *), void *d) +{ + assert(t); + if (!isbranch(t)) + return f(tvalp(t), d); + uint n = branch_weight(t); + for (uint i = 0; i < n; ++i) + ERR_RETURN(apply_nodes(twig(t, i), f, d)); + return KNOT_EOK; +} + +int trie_apply(trie_t *tbl, int (*f)(trie_val_t *, void *), void *d) +{ + assert(tbl && f); + if (!tbl->weight) + return KNOT_EOK; + return apply_nodes(&tbl->root, f, d); +} + +/* These are all thin wrappers around static Tns* functions. */ +trie_it_t* trie_it_begin(trie_t *tbl) +{ + assert(tbl); + trie_it_t *it = malloc(sizeof(nstack_t)); + if (!it) + return NULL; + ns_init(it, tbl); + if (it->len == 0) // empty tbl + return it; + if (ns_first_leaf(it)) { + ns_cleanup(it); + free(it); + return NULL; + } + return it; +} + +bool trie_it_finished(trie_it_t *it) +{ + assert(it); + return it->len == 0; +} + +void trie_it_free(trie_it_t *it) +{ + if (!it) + return; + ns_cleanup(it); + free(it); +} + +trie_it_t *trie_it_clone(const trie_it_t *it) +{ + if (!it) // TODO: or should that be an assertion? + return NULL; + trie_it_t *it2 = malloc(sizeof(nstack_t)); + if (!it2) + return NULL; + it2->len = it->len; + it2->alen = it->alen; // we _might_ change it in the rare malloc case, but... + if (likely(it->stack == it->stack_init)) { + it2->stack = it2->stack_init; + assert(it->alen == sizeof(it->stack_init) / sizeof(it->stack_init[0])); + } else { + it2->stack = malloc(it2->alen * sizeof(it2->stack[0])); + if (!it2->stack) { + free(it2); + return NULL; + } + } + memcpy(it2->stack, it->stack, it->len * sizeof(it->stack[0])); + return it2; +} + +const trie_key_t* trie_it_key(trie_it_t *it, size_t *len) +{ + assert(it && it->len); + node_t *t = it->stack[it->len - 1]; + assert(!isbranch(t)); + tkey_t *key = tkey(t); + if (len) + *len = key->len; + return key->chars; +} + +trie_val_t* trie_it_val(trie_it_t *it) +{ + assert(it && it->len); + node_t *t = it->stack[it->len - 1]; + assert(!isbranch(t)); + return tvalp(t); +} + +void trie_it_next(trie_it_t *it) +{ + assert(it && it->len); + if (ns_next_leaf(it, false) != KNOT_EOK) + it->len = 0; +} + +void trie_it_next_loop(trie_it_t *it) +{ + assert(it && it->len); + int ret = ns_next_leaf(it, false); + if (ret == KNOT_ENOENT) { + it->len = 1; + ret = ns_first_leaf(it); + } + if (ret) + it->len = 0; +} + +void trie_it_next_nosuffix(trie_it_t *it) +{ + assert(it && it->len); + if (ns_next_leaf(it, true) != KNOT_EOK) + it->len = 0; +} + +void trie_it_prev(trie_it_t *it) +{ + assert(it && it->len); + if (ns_prev_leaf(it) != KNOT_EOK) + it->len = 0; +} + +void trie_it_prev_loop(trie_it_t *it) +{ + assert(it && it->len); + int ret = ns_prev_leaf(it); + if (ret == KNOT_ENOENT) { + it->len = 1; + ret = ns_last_leaf(it); + } + if (ret) + it->len = 0; +} + +void trie_it_parent(trie_it_t *it) +{ + assert(it && it->len); + if (ns_prefix(it)) + it->len = 0; +} + +void trie_it_del(trie_it_t *it) +{ + assert(it && it->len); + if (it->len == 0) + return; + node_t *t = it->stack[it->len - 1]; + assert(!isbranch(t)); + bitmap_t b; // del_found() needs to know which bit to zero in the bitmap + node_t *p; + if (it->len == 1) { // deleting the root + p = NULL; + b = 0; // unused + } else { + p = it->stack[it->len - 2]; + assert(isbranch(p)); + size_t len; + const trie_key_t *key = trie_it_key(it, &len); + b = twigbit(p, key, len); + } + // We could trie_it_{next,prev,...}(it) now, in case we wanted that semantics. + it->len = 0; + del_found(ns_gettrie(it), t, p, b, NULL); +} + + +/*!\file + * + * \section About copy-on-write + * + * In these notes I'll use the term "object" to refer to either the + * twig array of a branch, or the application's data that is referred + * to by a leaf's trie_val_t pointer. Note that for COW we don't care + * about trie node_t structs themselves, but the objects that they + * point to. + * + * \subsection COW states + * + * During a COW transaction an object can be in one of three states: + * shared, only in the old trie, or only in the new trie. When a + * transaction is rolled back, the only-new objects are freed; when a + * transaction is committed the new trie takes the place of the old + * one and only-old objects are freed. + * + * \subsection branch marks and regions + * + * A branch object can be marked by setting the COW flag in the first + * element of its twig array. Marked branches partition the trie into + * regions; an object's state depends on its region. + * + * The unmarked branch objects between a trie's root and the marked + * branches (excluding the marked branches themselves) is exclusively + * owned: either old-only (if you started from the old root) or + * new-only (if you started from the new root). + * + * Marked branch objects, and all objects reachable from marked branch + * objects, are in the shared region accessible from both old and new + * roots. All branch objects below a marked branch must be unmarked. + * (That is, there is at most one marked branch object on any path + * from the root of a trie.) + * + * Branch nodes in the new-only region can be modified in place, in + * the same way as an original qp trie. Branch nodes in the old-only + * or shared regions must not be modified. + * + * \subsection app object states + * + * The app objects reachable from the new-only and old-only regions + * explicitly record their state in a way determined by the + * application. (These app objects are reachable from the old and new + * roots by traversing only unmarked branch objects.) + * + * The app objects reachable from marked branch objects are implicitly + * shared, but their state field has an indeterminate value. If an app + * object was previously touched by a rolled-back transaction it may + * be marked shared or old-only; if it was previously touched by a + * committed transaction it may be marked shared or new-only. + * + * \subsection key states + * + * The memory allocated for tkey_t objects also needs to track its + * sharing state. They have a "cow" flag to mark when they are shared. + * Keys are relatively lazily copied (to make them exclusive) when + * their leaf node is touched by a COW mutation. + * + * [An alternative technique might be to copy them more eagerly, in + * cow_pushdown(), which would avoid the need for a flag bit at the + * cost of more allocator churn in a transaction.] + * + * \subsection outside COW + * + * When a COW transaction is not in progress, there are no marked + * branch objects, so everything is exclusively owned. When a COW + * transaction is finished (committed or rolled back), the branch + * marks are removed. Since they are in the shared region, this branch + * cleanup is visible to both old and new tries. + * + * However the state of app objects is not clean between COW + * transactions. When a COW transaction is committed, we traverse the + * old-only region to find old-only app objects that should be freed + * (and vice versa for rollback). In general, there will be app + * objects that are only reachable from the new-only region, and that + * have a mixture of shared and new states. + */ + +/*! \brief Trie copy-on-write state */ +struct trie_cow { + trie_t *old; + trie_t *new; + trie_cb *mark_shared; + void *d; +}; + +/*! \brief is this a marked branch object */ +static bool cow_marked(node_t *t) +{ + return isbranch(t) && (twigs(t)->i & TFLAG_COW); +} + +/*! \brief is this a leaf with a marked key */ +static bool cow_key(node_t *t) +{ + return !isbranch(t) && tkey(t)->cow; +} + +/*! \brief remove mark from a branch object */ +static void clear_cow(node_t *t) +{ + assert(isbranch(t)); + twigs(t)->i &= ~TFLAG_COW; +} + +/*! \brief mark a node as shared + * + * For branches this marks the twig array (in COW terminology, the + * branch object); for leaves it uses the callback to mark the app + * object. + */ +static void mark_cow(trie_cow_t *cow, node_t *t) +{ + if (isbranch(t)) { + node_t *object = twigs(t); + object->i |= TFLAG_COW; + } else { + tkey_t *lkey = tkey(t); + lkey->cow = 1; + if (cow->mark_shared != NULL) { + trie_val_t *valp = tvalp(t); + cow->mark_shared(*valp, lkey->chars, lkey->len, cow->d); + } + } +} + +/*! \brief push exclusive COW region down one node */ +static int cow_pushdown_one(trie_cow_t *cow, node_t *t) +{ + uint cc = branch_weight(t); + node_t *nt = mm_alloc(&cow->new->mm, sizeof(node_t) * cc); + if (nt == NULL) + return KNOT_ENOMEM; + /* mark all the children */ + for (uint ci = 0; ci < cc; ++ci) + mark_cow(cow, twig(t, ci)); + /* this node must be unmarked in both old and new versions */ + clear_cow(t); + t->p = memcpy(nt, twigs(t), sizeof(node_t) * cc); + return KNOT_EOK; +} + +/*! \brief push exclusive COW region to cover a whole node stack */ +static int cow_pushdown(trie_cow_t *cow, nstack_t *ns) +{ + node_t *new_twigs = NULL; + node_t *old_twigs = NULL; + for (uint i = 0; i < ns->len; i++) { + /* if we did a pushdown on the previous iteration, we + need to update this stack entry so it points into + the parent's new twigs instead of the old ones */ + if (new_twigs != old_twigs) + ns->stack[i] = new_twigs + (ns->stack[i] - old_twigs); + if (cow_marked(ns->stack[i])) { + old_twigs = twigs(ns->stack[i]); + if (cow_pushdown_one(cow, ns->stack[i])) + return KNOT_ENOMEM; + new_twigs = twigs(ns->stack[i]); + } else { + new_twigs = NULL; + old_twigs = NULL; + /* ensure key is exclusively owned */ + if (cow_key(ns->stack[i])) { + node_t oleaf = *ns->stack[i]; + tkey_t *okey = tkey(&oleaf); + if(mkleaf(ns->stack[i], okey->chars, okey->len, + &cow->new->mm)) + return KNOT_ENOMEM; + ns->stack[i]->p = oleaf.p; + okey->cow = 0; + } + } + } + return KNOT_EOK; +} + +trie_cow_t* trie_cow(trie_t *old, trie_cb *mark_shared, void *d) +{ + knot_mm_t *mm = &old->mm; + trie_t *new = mm_alloc(mm, sizeof(trie_t)); + trie_cow_t *cow = mm_alloc(mm, sizeof(trie_cow_t)); + if (new == NULL || cow == NULL) { + mm_free(mm, new); + mm_free(mm, cow); + return NULL; + } + new->mm = old->mm; + new->root = old->root; + new->weight = old->weight; + cow->old = old; + cow->new = new; + cow->mark_shared = mark_shared; + cow->d = d; + if (old->weight) + mark_cow(cow, &old->root); + return cow; +} + +trie_t* trie_cow_new(trie_cow_t *cow) +{ + assert(cow != NULL); + return cow->new; +} + +trie_val_t* trie_get_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len) +{ + return cow_get_ins(cow, cow->new, key, len); +} + +int trie_del_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len, trie_val_t *val) +{ + trie_t *tbl = cow->new; + if (unlikely(!tbl->weight)) + return KNOT_ENOENT; + { // Intentionally un-indented; until end of function, to bound cleanup attr. + // Find the branching-point + __attribute__((cleanup(ns_cleanup))) + nstack_t ns_local; + ns_init(&ns_local, tbl); + nstack_t *ns = &ns_local; + index_t idiff; + bitmap_t tbit, kbit; + ERR_RETURN(ns_find_branch(ns, key, len, &idiff, &tbit, &kbit)); + if (idiff != TMAX_INDEX) + return KNOT_ENOENT; + ERR_RETURN(cow_pushdown(cow, ns)); + node_t *t = ns->stack[ns->len - 1]; + node_t *p = ns->len >= 2 ? ns->stack[ns->len - 2] : NULL; + del_found(tbl, t, p, p ? twigbit(p, key, len) : 0, val); + } + return KNOT_EOK; +} + +/*! \brief clean up after a COW transaction, recursively */ +static void cow_cleanup(trie_cow_t *cow, node_t *t, trie_cb *cb, void *d) +{ + if (cow_marked(t)) { + // we have hit the shared region, so just reset the mark + clear_cow(t); + return; + } else if (isbranch(t)) { + // traverse and free the exclusive region + uint cc = branch_weight(t); + for (uint ci = 0; ci < cc; ++ci) + cow_cleanup(cow, twig(t, ci), cb, d); + mm_free(&cow->new->mm, twigs(t)); + return; + } else { + // application must decide how to clean up its values + tkey_t *lkey = tkey(t); + if (cb != NULL) { + trie_val_t *valp = tvalp(t); + cb(*valp, lkey->chars, lkey->len, d); + } + // clean up exclusively-owned keys + if (lkey->cow) + lkey->cow = 0; + else + mm_free(&cow->new->mm, lkey); + return; + } +} + +trie_t* trie_cow_commit(trie_cow_t *cow, trie_cb *cb, void *d) +{ + trie_t *ret = cow->new; + if (cow->old->weight) + cow_cleanup(cow, &cow->old->root, cb, d); + mm_free(&ret->mm, cow->old); + mm_free(&ret->mm, cow); + return ret; +} + +trie_t* trie_cow_rollback(trie_cow_t *cow, trie_cb *cb, void *d) +{ + trie_t *ret = cow->old; + if (cow->new->weight) + cow_cleanup(cow, &cow->new->root, cb, d); + mm_free(&ret->mm, cow->new); + mm_free(&ret->mm, cow); + return ret; +} diff --git a/utils/kresctl/deps/qp-trie/trie.h b/utils/kresctl/deps/qp-trie/trie.h new file mode 100644 index 000000000..14fb1a2d0 --- /dev/null +++ b/utils/kresctl/deps/qp-trie/trie.h @@ -0,0 +1,280 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + Copyright (C) 2018 Tony Finch + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include +#include + +#include "libknot/mm_ctx.h" + +/*! + * \brief Native API of QP-tries: + * + * - keys are uint8_t strings, not necessarily zero-terminated, + * the structure copies the contents of the passed keys + * - values are void* pointers, typically you get an ephemeral pointer to it + * - key lengths are limited by 2^32-1 ATM + */ + +/*! \brief Element value. */ +typedef void* trie_val_t; +/*! \brief Key for indexing tries. Sign could be flipped easily. */ +typedef uint8_t trie_key_t; + +/*! \brief Opaque structure holding a QP-trie. */ +typedef struct trie trie_t; + +/*! \brief Opaque type for holding a QP-trie iterator. */ +typedef struct trie_it trie_it_t; + +/*! \brief Callback for cloning trie values. */ +typedef trie_val_t (*trie_dup_cb)(const trie_val_t val, knot_mm_t *mm); + +/*! \brief Callback for performing actions on a trie leaf + * + * Used during copy-on-write transactions + * + * \param val The value of the element to be altered + * \param key The key of the element to be altered + * \param len The length of key + * \param d Additional user data + */ +typedef void trie_cb(trie_val_t val, const trie_key_t *key, size_t len, void *d); + +/*! \brief Opaque type for holding the copy-on-write state for a QP-trie. */ +typedef struct trie_cow trie_cow_t; + +/*! \brief Create a trie instance. */ +trie_t* trie_create(knot_mm_t *mm); + +/*! \brief Free a trie instance. */ +void trie_free(trie_t *tbl); + +/*! \brief Clear a trie instance (make it empty). */ +void trie_clear(trie_t *tbl); + +/*! \brief Create a clone of existing trie. */ +trie_t* trie_dup(const trie_t *orig, trie_dup_cb dup_cb, knot_mm_t *mm); + +/*! \brief Return the number of keys in the trie. */ +size_t trie_weight(const trie_t *tbl); + +/*! \brief Search the trie, returning NULL on failure. */ +trie_val_t* trie_get_try(trie_t *tbl, const trie_key_t *key, uint32_t len); + +/*! \brief Search the trie including DNS wildcard semantics, returning NULL on failure. + * + * \note We assume the key is in knot_dname_lf() format, i.e. labels are ordered + * from root to leaf and separated by zero bytes (and no other zeros are allowed). + * \note Beware that DNS wildcard matching is not exactly what normal people would expect. + */ +trie_val_t* trie_get_try_wildcard(trie_t *tbl, const trie_key_t *key, uint32_t len); + +/*! \brief Search the trie, inserting NULL trie_val_t on failure. */ +trie_val_t* trie_get_ins(trie_t *tbl, const trie_key_t *key, uint32_t len); + +/*! + * \brief Search for less-or-equal element. + * + * \param tbl Trie. + * \param key Searched key. + * \param len Key length. + * \param val (optional) Value found; it will be set to NULL if not found or errored. + * \return KNOT_EOK for exact match, 1 for previous, KNOT_ENOENT for not-found, + * or KNOT_E*. + */ +int trie_get_leq(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t **val); + +/*! + * \brief Apply a function to every trie_val_t, in order. + * + * \return KNOT_EOK if success or KNOT_E* if error. + */ +int trie_apply(trie_t *tbl, int (*f)(trie_val_t *, void *), void *d); + +/*! + * \brief Remove an item, returning KNOT_EOK if succeeded or KNOT_ENOENT if not found. + * + * If val!=NULL and deletion succeeded, the deleted value is set. + */ +int trie_del(trie_t *tbl, const trie_key_t *key, uint32_t len, trie_val_t *val); + + +/*! \brief Create a new iterator pointing to the first element (if any). + * + * trie_it_* functions deal with these iterators capable of walking and jumping + * over the trie. Note that any modification to key-set stored by the trie + * will in general invalidate all iterators and you will need to begin anew. + * (It won't be detected - you may end up reading freed memory, etc.) + */ +trie_it_t* trie_it_begin(trie_t *tbl); + +/*! \brief Test if the iterator has gone "past the end" (and points nowhere). */ +bool trie_it_finished(trie_it_t *it); + +/*! \brief Free any resources of the iterator. It's OK to call it on NULL. */ +void trie_it_free(trie_it_t *it); + +/*! \brief Copy the iterator. See the warning in trie_it_begin(). */ +trie_it_t *trie_it_clone(const trie_it_t *it); + +/*! + * \brief Return pointer to the key of the current element. + * + * \note The len is uint32_t internally but size_t is better for our usage + * as it is without an additional type conversion. + */ +const trie_key_t* trie_it_key(trie_it_t *it, size_t *len); + +/*! \brief Return pointer to the value of the current element (writable). */ +trie_val_t* trie_it_val(trie_it_t *it); + +/*! + * \brief Advance the iterator to the next element. + * + * Iteration is in ascending lexicographical order. + * In particular, the empty string would be considered as the very first. + * + * \TODO: in most iterator operations, ENOMEM is very unlikely + * but it leads to a _finished() iterator (silently). + * Perhaps the functions should simply return KNOT_E* + */ +void trie_it_next(trie_it_t *it); +/*! \brief Advance the iterator to the previous element. See trie_it_next(). */ +void trie_it_prev(trie_it_t *it); + +/*! \brief Advance iterator to the next element, looping to first after last. */ +void trie_it_next_loop(trie_it_t *it); +/*! \brief Advance iterator to the previous element, looping to last after first. */ +void trie_it_prev_loop(trie_it_t *it); + +/*! \brief Advance iterator to the next element while ignoring the subtree. + * + * \note Another formulation: skip keys that are prefixed by the current key. + * \TODO: name, maybe _noprefixed? The thing is that in the "subtree" meaning + * doesn't correspond to how the pointers go in the implementation, + * but we may not care much for implementation in the API... + */ +void trie_it_next_nosub(trie_it_t *it); + +/*! \brief Advance iterator to the longest prefix of the current key. + * + * \TODO: name, maybe _prefix? Arguments similar to _nosub vs. _noprefixed. + */ +void trie_it_parent(trie_it_t *it); + +/*! \brief trie_get_leq() but with an iterator. */ +int trie_it_get_leq(trie_it_t *it, const trie_key_t *key, uint32_t len); + +/*! \brief Remove the current element. The iterator will get trie_it_finished() */ +void trie_it_del(trie_it_t *it); + + +/*! \brief Start a COW transaction + * + * A copy-on-write transaction starts by obtaining a write lock (in + * your application code) followed by a call to trie_cow(). This + * creates a shared clone of the trie and saves both old and new roots + * in the COW context. + * + * During the COW transaction, you call trie_cow_ins() or + * trie_cow_del() as necessary. These calls ensure that the relevant + * parts of the (new) trie are copied so that they can be modified + * freely. + * + * Your trie_val_t objects must be able to distinguish their + * reachability, either shared, or old-only, or new-only. Before a COW + * transaction the reachability of your objects is indeterminate. + * During a transaction, any trie_val_t objects that might be affected + * (because they are adjacent to a trie_get_cow() or trie_del_cow()) + * are first marked as shared using the callback you pass to + * trie_cow(). + * + * When the transaction is complete, to commit, call trie_cow_new() to + * get the new root, swap the old and new trie roots (e.g. with + * rcu_xchg_pointer()), wait for readers to finish with the old trie + * (e.g. using synchronize_rcu()), then call trie_cow_commit(). For a + * rollback, you can just call trie_cow_rollback() without waiting + * since that doesn't conflict with readers. After trie_cow_commit() + * or trie_cow_rollback() have finished, you can release your write + * lock. + * + * Concurrent reading of the old trie is allowed during a transaction + * provided that it is known when all readers have finished with the + * old version, e.g. using rcu_read_lock() and rcu_read_unlock(). + * There must be only one write transaction at a time. + * + * \param old the old trie + * \param mark_shared callback to mark a leaf as shared (can be NULL) + * \param d extra data for the callback + * \return a pointer to a COW context, + * or NULL if there was a failure + */ +trie_cow_t* trie_cow(trie_t *old, trie_cb *mark_shared, void *d); + +/*! \brief get the new trie from a COW context */ +trie_t* trie_cow_new(trie_cow_t *cow); + +/*! \brief variant of trie_get_ins() for use during COW transactions + * + * As necessary, this copies path from the root of the trie to the + * leaf, so that it is no longer shared. Any leaves adjacent to this + * path are marked as shared using the mark_shared callback passed to + * trie_cow(). + * + * It is your responsibility to COW your trie_val_t objects. If you copy an + * object you must change the original's reachability from shared to old-only. + * New objects (including copies) must have new-only reachability. + */ +trie_val_t* trie_get_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len); + +/*! + * \brief variant of trie_del() for use during COW transactions + * + * The mark_shared callback is invoked as necessary, in the same way + * as trie_get_cow(). + * + * Returns KNOT_EOK if the key was removed or KNOT_ENOENT if not found. + * If val!=NULL and deletion succeeded, the *val is set to the deleted + * value pointer. + */ +int trie_del_cow(trie_cow_t *cow, const trie_key_t *key, uint32_t len, trie_val_t *val); + +/*! \brief clean up the old trie after committing a COW transaction + * + * Your callback is invoked for any trie_val_t objects that might need + * cleaning up; you must free any objects you have marked as old-only + * and retain objects with shared reachability. + * + * \note The callback can be NULL. + * + * The cow object is free()d, and the new trie root is returned. + */ +trie_t* trie_cow_commit(trie_cow_t *cow, trie_cb *cb, void *d); + +/*! \brief clean up the new trie after rolling back a COW transaction + * + * Your callback is invoked for any trie_val_t objects that might need + * cleaning up; you must free any objects you have marked as new-only + * and retain objects with shared reachability. + * + * \note The callback can be NULL. + * + * The cow object is free()d, and the old trie root is returned. + */ +trie_t* trie_cow_rollback(trie_cow_t *cow, trie_cb *cb, void *d); diff --git a/utils/kresctl/deps/string.c b/utils/kresctl/deps/string.c new file mode 100644 index 000000000..e8ff3267e --- /dev/null +++ b/utils/kresctl/deps/string.c @@ -0,0 +1,215 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include +#include +#include +#include +#include +#if defined(HAVE_EXPLICIT_BZERO) + #if defined(HAVE_BSD_STRING_H) + #include + #endif + /* #include is needed. */ +#elif defined(HAVE_EXPLICIT_MEMSET) + /* #include is needed. */ +#elif defined(HAVE_GNUTLS_MEMSET) + #include +#else + #define USE_CUSTOM_MEMSET +#endif + +#include "string.h" +#include "ctype.h" + +uint8_t *memdup(const uint8_t *data, size_t data_size) +{ + uint8_t *result = (uint8_t *)malloc(data_size); + if (!result) { + return NULL; + } + + return memcpy(result, data, data_size); +} + +char *sprintf_alloc(const char *fmt, ...) +{ + char *strp = NULL; + va_list ap; + + va_start(ap, fmt); + int ret = vasprintf(&strp, fmt, ap); + va_end(ap); + + if (ret < 0) { + return NULL; + } + return strp; +} + +char *strcdup(const char *s1, const char *s2) +{ + if (!s1 || !s2) { + return NULL; + } + + size_t s1len = strlen(s1); + size_t s2len = strlen(s2); + size_t nlen = s1len + s2len + 1; + + char* dst = malloc(nlen); + if (dst == NULL) { + return NULL; + } + + memcpy(dst, s1, s1len); + memcpy(dst + s1len, s2, s2len + 1); + return dst; +} + +char *strstrip(const char *str) +{ + // leading white-spaces + const char *scan = str; + while (is_space(scan[0])) { + scan += 1; + } + + // trailing white-spaces + size_t len = strlen(scan); + while (len > 0 && is_space(scan[len - 1])) { + len -= 1; + } + + char *trimmed = malloc(len + 1); + if (!trimmed) { + return NULL; + } + + memcpy(trimmed, scan, len); + trimmed[len] = '\0'; + + return trimmed; +} + +int const_time_memcmp(const void *s1, const void *s2, size_t n) +{ + volatile uint8_t equal = 0; + + for (size_t i = 0; i < n; i++) { + equal |= ((uint8_t *)s1)[i] ^ ((uint8_t *)s2)[i]; + } + + return equal; +} + +#if defined(USE_CUSTOM_MEMSET) +typedef void *(*memset_t)(void *, int, size_t); +static volatile memset_t volatile_memset = memset; +#endif + +void *memzero(void *s, size_t n) +{ +#if defined(HAVE_EXPLICIT_BZERO) /* In OpenBSD since 5.5. */ + /* In FreeBSD since 11.0. */ + /* In glibc since 2.25. */ + /* In DragonFly BSD since 5.5. */ + explicit_bzero(s, n); + return s; +#elif defined(HAVE_EXPLICIT_MEMSET) /* In NetBSD since 7.0. */ + return explicit_memset(s, 0, n); +#elif defined(HAVE_GNUTLS_MEMSET) /* In GnuTLS since 3.4.0. */ + gnutls_memset(s, 0, n); + return s; +#else /* Knot custom solution as a fallback. */ + /* Warning: the use of the return value is *probably* needed + * so as to avoid the volatile_memset() to be optimized out. + */ + return volatile_memset(s, 0, n); +#endif +} + +static const char BIN_TO_HEX[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' +}; + +char *bin_to_hex(const uint8_t *bin, size_t bin_len) +{ + if (bin == NULL) { + return NULL; + } + + size_t hex_size = bin_len * 2; + char *hex = malloc(hex_size + 1); + if (hex == NULL) { + return NULL; + } + + for (size_t i = 0; i < bin_len; i++) { + hex[2 * i] = BIN_TO_HEX[bin[i] >> 4]; + hex[2 * i + 1] = BIN_TO_HEX[bin[i] & 0x0f]; + } + hex[hex_size] = '\0'; + + return hex; +} + +/*! + * Convert HEX character to numeric value (assumes valid input). + */ +static uint8_t hex_to_number(const char hex) +{ + if (hex >= '0' && hex <= '9') { + return hex - '0'; + } else if (hex >= 'a' && hex <= 'f') { + return hex - 'a' + 10; + } else { + assert(hex >= 'A' && hex <= 'F'); + return hex - 'A' + 10; + } +} + +uint8_t *hex_to_bin(const char *hex, size_t *out_len) +{ + if (hex == NULL || out_len == NULL) { + return NULL; + } + + size_t hex_len = strlen(hex); + if (hex_len % 2 != 0) { + return NULL; + } + + size_t bin_len = hex_len / 2; + uint8_t *bin = malloc(bin_len + 1); + if (bin == NULL) { + return NULL; + } + + for (size_t i = 0; i < bin_len; i++) { + if (!is_xdigit(hex[2 * i]) || !is_xdigit(hex[2 * i + 1])) { + free(bin); + return NULL; + } + uint8_t high = hex_to_number(hex[2 * i]); + uint8_t low = hex_to_number(hex[2 * i + 1]); + bin[i] = high << 4 | low; + } + + *out_len = bin_len; + + return bin; +} diff --git a/utils/kresctl/deps/string.h b/utils/kresctl/deps/string.h new file mode 100644 index 000000000..205c699a8 --- /dev/null +++ b/utils/kresctl/deps/string.h @@ -0,0 +1,93 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +/*! + * \brief String manipulations. + */ + +#pragma once + +#include +#include + +/*! + * \brief Create a copy of a binary buffer. + * + * Like \c strdup, but for binary data. + */ +uint8_t *memdup(const uint8_t *data, size_t data_size); + +/*! + * \brief Format string and take care of allocating memory. + * + * \note sprintf(3) manual page reference implementation. + * + * \param fmt Message format. + * \return formatted message or NULL. + */ +char *sprintf_alloc(const char *fmt, ...); + +/*! + * \brief Create new string from a concatenation of s1 and s2. + * + * \param s1 First string. + * \param s2 Second string. + * + * \retval Newly allocated string on success. + * \retval NULL on error. + */ +char *strcdup(const char *s1, const char *s2); + +/*! + * \brief Create a copy of a string skipping leading and trailing white spaces. + * + * \return Newly allocated string, NULL in case of error. + */ +char *strstrip(const char *str); + +/*! + * \brief Compare data in time based on string length. + * This function just checks for (in)equality not for relation + * + * \param s1 The first address to compare. + * \param s2 The second address to compare. + * \param n The size of memory to compare. + * + * \return Non zero on difference and zero if the buffers are identical. + */ +int const_time_memcmp(const void *s1, const void *s2, size_t n); + +/*! + * \brief Fill memory with zeroes. + * + * Inspired by OPENSSL_cleanse. Such a memset shouldn't be optimized out. + * + * \param s The address to fill. + * \param n The size of memory to fill. + * + * \return Pointer to the memory. + */ +void *memzero(void *s, size_t n); + +/*! + * \brief Convert binary data to hexadecimal string. + */ +char *bin_to_hex(const uint8_t *bin, size_t bin_len); + +/*! + * \brief Convert hex encoded string to binary data. + */ +uint8_t *hex_to_bin(const char *hex, size_t *out_len); diff --git a/utils/kresctl/deps/ucw/LICENSE b/utils/kresctl/deps/ucw/LICENSE new file mode 100644 index 000000000..b463d5709 --- /dev/null +++ b/utils/kresctl/deps/ucw/LICENSE @@ -0,0 +1 @@ +../licenses/LGPL-2.0 \ No newline at end of file diff --git a/utils/kresctl/deps/ucw/array-sort.h b/utils/kresctl/deps/ucw/array-sort.h new file mode 100644 index 000000000..1ff137717 --- /dev/null +++ b/utils/kresctl/deps/ucw/array-sort.h @@ -0,0 +1,195 @@ +/* + * UCW Library -- Universal Simple Array Sorter + * + * (c) 2003--2008 Martin Mares + * + * This software may be freely distributed and used according to the terms + * of the GNU Lesser General Public License. + */ + +#pragma once + +#include "contrib/macros.h" + +/* + * This is not a normal header file, it's a generator of sorting + * routines. Each time you include it with parameters set in the + * corresponding preprocessor macros, it generates an array sorter + * with the parameters given. + * + * You might wonder why the heck do we implement our own array sorter + * instead of using qsort(). The primary reason is that qsort handles + * only continuous arrays, but we need to sort array-like data structures + * where the only way to access elements is by using an indexing macro. + * Besides that, we are more than 2 times faster. + * + * So much for advocacy, there are the parameters (those marked with [*] + * are mandatory): + * + * ASORT_PREFIX(x) [*] add a name prefix (used on all global names + * defined by the sorter) + * ASORT_KEY_TYPE [*] data type of a single array entry key + * ASORT_ELT(i) returns the key of i-th element; if this macro is not + * defined, the function gets a pointer to an array to be sorted + * ASORT_LT(x,y) x < y for ASORT_KEY_TYPE (default: "x= ASORT_THRESHOLD && (right - l) >= ASORT_THRESHOLD) + { + /* Both partitions ok => push the larger one */ + if ((r - left) > (right - l)) + { + stack[sp].l = left; + stack[sp].r = r; + left = l; + } + else + { + stack[sp].l = l; + stack[sp].r = right; + right = r; + } + sp++; + } + else if ((r - left) >= ASORT_THRESHOLD) + { + /* Left partition OK, right undersize */ + right = r; + } + else if ((right - l) >= ASORT_THRESHOLD) + { + /* Right partition OK, left undersize */ + left = l; + } + else + { + /* Both partitions undersize => pop */ + if (!sp) + break; + sp--; + left = stack[sp].l; + right = stack[sp].r; + } + } + + /* + * We have a partially sorted array, finish by insertsort. Inspired + * by qsort() in GNU libc. + */ + + /* Find minimal element which will serve as a barrier */ + r = MIN(array_size, ASORT_THRESHOLD); + m = 0; + for (l=1; l + * + * This software may be freely distributed and used according to the terms + * of the GNU Lesser General Public License. + */ + +#pragma once + +/*** + * [[defs]] + * Definitions + * ----------- + ***/ + +/** + * Find the first element not lower than \p x in the sorted array \p ary of \p N elements (non-decreasing order). + * Returns the index of the found element or \p N if no exists. Uses `ary_lt_x(ary,i,x)` to compare the i'th element with \p x. + * The time complexity is `O(log(N))`. + **/ +#define BIN_SEARCH_FIRST_GE_CMP(ary, N, ary_lt_x, x, ...) ({ \ + unsigned l = 0, r = (N); \ + while (l < r) \ + { \ + unsigned m = (l+r)/2; \ + if (ary_lt_x(ary, m, x, __VA_ARGS__)) \ + l = m+1; \ + else \ + r = m; \ + } \ + l; \ +}) + +/** + * The default comparison macro for \ref BIN_SEARCH_FIRST_GE_CMP(). + **/ +#define ARY_LT_NUM(ary,i,x) (ary)[i] < (x) + +/** + * Same as \ref BIN_SEARCH_FIRST_GE_CMP(), but uses the default `<` operator for comparisons. + **/ +#define BIN_SEARCH_FIRST_GE(ary,N,x) BIN_SEARCH_FIRST_GE_CMP(ary,N,ARY_LT_NUM,x) + +/** + * Search the sorted array \p ary of \p N elements (non-decreasing) for the first occurrence of \p x. + * Returns the index or -1 if no such element exists. Uses the `<` operator for comparisons. + **/ +#define BIN_SEARCH_EQ(ary,N,x) ({ int i = BIN_SEARCH_FIRST_GE(ary,N,x); if (i >= (N) || (ary)[i] != (x)) i=-1; i; }) diff --git a/utils/kresctl/deps/ucw/heap.c b/utils/kresctl/deps/ucw/heap.c new file mode 100644 index 000000000..d7ed18e08 --- /dev/null +++ b/utils/kresctl/deps/ucw/heap.c @@ -0,0 +1,166 @@ +/* + * Binary heap + * + * (c) 2012 Ondrej Filip + * + * This software may be freely distributed and used according to the terms + * of the GNU Lesser General Public License. + */ + +/*** + * Introduction + * ------------ + * + * Binary heap is a simple data structure, which for example supports efficient insertions, deletions + * and access to the minimal inserted item. We define several macros for such operations. + * Note that because of simplicity of heaps, we have decided to define direct macros instead + * of a <> as for several other data structures in the Libucw. + * + * A heap is represented by a number of elements and by an array of values. Beware that we + * index this array from one, not from zero as do the standard C arrays. + * + * Most macros use these parameters: + * + * - @num - a variable (signed or unsigned integer) with the number of elements + * - @heap - a C array of type @type; the heap is stored in `heap[1] .. heap[num]`; `heap[0]` is unused + * + * A valid heap must follow these rules: + * + * - `num >= 0` + * - `heap[i] >= heap[i / 2]` for each `i` in `[2, num]` + * + * The first element `heap[1]` is always lower or equal to all other elements. + ***/ + +#include +#include +#include "contrib/ucw/heap.h" + +static inline void heap_swap(heap_val_t **e1, heap_val_t **e2) +{ + if (e1 == e2) return; /* Stack tmp should be faster than tmpelem. */ + heap_val_t *tmp = *e1; /* Even faster than 2-XOR nowadays. */ + *e1 = *e2; + *e2 = tmp; + int pos = (*e1)->pos; + (*e1)->pos = (*e2)->pos; + (*e2)->pos = pos; +} + +int heap_init(struct heap *h, int (*cmp)(void *, void *), int init_size) +{ + int isize = init_size ? init_size : INITIAL_HEAP_SIZE; + + h->num = 0; + h->max_size = isize; + h->cmp = cmp; + h->data = malloc((isize + 1) * sizeof(heap_val_t*)); /* Temp element unused. */ + + return h->data ? 1 : 0; +} + +void heap_deinit(struct heap *h) +{ + free(h->data); + memset(h, 0, sizeof(*h)); +} + +static inline void _heap_bubble_down(struct heap *h, int e) +{ + int e1; + for (;;) + { + e1 = 2*e; + if(e1 > h->num) break; + if((h->cmp(*HELEMENT(h, e),*HELEMENT(h,e1)) < 0) && (e1 == h->num || (h->cmp(*HELEMENT(h, e),*HELEMENT(h,e1+1)) < 0))) break; + if((e1 != h->num) && (h->cmp(*HELEMENT(h, e1+1), *HELEMENT(h,e1)) < 0)) e1++; + heap_swap(HELEMENT(h,e),HELEMENT(h,e1)); + e = e1; + } +} + +static inline void _heap_bubble_up(struct heap *h, int e) +{ + int e1; + while (e > 1) + { + e1 = e/2; + if(h->cmp(*HELEMENT(h, e1),*HELEMENT(h,e)) < 0) break; + heap_swap(HELEMENT(h,e),HELEMENT(h,e1)); + e = e1; + } + +} + +static void heap_increase(struct heap *h, int pos, heap_val_t *e) +{ + *HELEMENT(h, pos) = e; + e->pos = pos; + _heap_bubble_down(h, pos); +} + +static void heap_decrease(struct heap *h, int pos, heap_val_t *e) +{ + *HELEMENT(h, pos) = e; + e->pos = pos; + _heap_bubble_up(h, pos); +} + +void heap_replace(struct heap *h, int pos, heap_val_t *e) +{ + if (h->cmp(*HELEMENT(h, pos),e) < 0) { + heap_increase(h, pos, e); + } else { + heap_decrease(h, pos, e); + } +} + +void heap_delmin(struct heap *h) +{ + if(h->num == 0) return; + if(h->num > 1) + { + heap_swap(HHEAD(h),HELEMENT(h,h->num)); + } + (*HELEMENT(h, h->num))->pos = 0; + --h->num; + _heap_bubble_down(h, 1); +} + +int heap_insert(struct heap *h, heap_val_t *e) +{ + if(h->num == h->max_size) + { + h->max_size = h->max_size * HEAP_INCREASE_STEP; + h->data = realloc(h->data, (h->max_size + 1) * sizeof(heap_val_t*)); + if (!h->data) { + return 0; + } + } + + h->num++; + *HELEMENT(h,h->num) = e; + e->pos = h->num; + _heap_bubble_up(h,h->num); + return 1; +} + +int heap_find(struct heap *h, heap_val_t *elm) +{ + return ((struct heap_val *) elm)->pos; +} + +void heap_delete(struct heap *h, int e) +{ + heap_swap(HELEMENT(h, e), HELEMENT(h, h->num)); + (*HELEMENT(h, h->num))->pos = 0; + h->num--; + if(h->cmp(*HELEMENT(h, e), *HELEMENT(h, h->num + 1)) < 0) _heap_bubble_up(h, e); + else _heap_bubble_down(h, e); + + if ((h->num > INITIAL_HEAP_SIZE) && (h->num < h->max_size / HEAP_DECREASE_THRESHOLD)) + { + h->max_size = h->max_size / HEAP_INCREASE_STEP; + h->data = realloc(h->data, (h->max_size + 1) * sizeof(heap_val_t*)); + } +} diff --git a/utils/kresctl/deps/ucw/heap.h b/utils/kresctl/deps/ucw/heap.h new file mode 100644 index 000000000..58958b361 --- /dev/null +++ b/utils/kresctl/deps/ucw/heap.h @@ -0,0 +1,46 @@ +/* Copyright (C) 2011 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +struct heap_val { + int pos; +}; + +typedef struct heap_val heap_val_t; + +struct heap { + int num; /* Number of elements */ + int max_size; /* Size of allocated memory */ + int (*cmp)(void *, void *); + heap_val_t **data; +}; /* Array follows */ + +#define INITIAL_HEAP_SIZE 512 /* initial heap size */ +#define HEAP_INCREASE_STEP 2 /* multiplier for each inflation, keep conservative */ +#define HEAP_DECREASE_THRESHOLD 2 /* threshold for deflation, keep conservative */ +#define HELEMENT(h,num) ((h)->data + (num)) +#define HHEAD(h) HELEMENT((h), 1) +#define EMPTY_HEAP(h) ((h)->num == 0) /* h->num == 0 */ + +int heap_init(struct heap *, int (*cmp)(void *, void *), int); +void heap_deinit(struct heap *); + +void heap_delmin(struct heap *); +int heap_insert(struct heap *, heap_val_t *); +int heap_find(struct heap *, heap_val_t *); +void heap_delete(struct heap *, int); +void heap_replace(struct heap *h, int pos, heap_val_t *); diff --git a/utils/kresctl/deps/ucw/lists.c b/utils/kresctl/deps/ucw/lists.c new file mode 100644 index 000000000..b3f254686 --- /dev/null +++ b/utils/kresctl/deps/ucw/lists.c @@ -0,0 +1,235 @@ +/* + * BIRD Library -- Linked Lists + * + * (c) 1998 Martin Mares + * (c) 2015, 2019 CZ.NIC, z.s.p.o. + * + * Can be freely distributed and used under the terms of the GNU GPL. + */ + +/** + * DOC: Linked lists + * + * The BIRD library provides a set of functions for operating on linked + * lists. The lists are internally represented as standard doubly linked + * lists with synthetic head and tail which makes all the basic operations + * run in constant time and contain no extra end-of-list checks. Each list + * is described by a &list structure, nodes can have any format as long + * as they start with a &node structure. If you want your nodes to belong + * to multiple lists at once, you can embed multiple &node structures in them + * and use the SKIP_BACK() macro to calculate a pointer to the start of the + * structure from a &node pointer, but beware of obscurity. + * + * There also exist safe linked lists (&slist, &snode and all functions + * being prefixed with |s_|) which support asynchronous walking very + * similar to that used in the &fib structure. + */ + +#include +#include +#include "contrib/ucw/lists.h" +#include "contrib/mempattern.h" + +/** + * add_tail - append a node to a list + * \p l: linked list + * \p n: list node + * + * add_tail() takes a node \p n and appends it at the end of the list \p l. + */ +void +add_tail(list_t *l, node_t *n) +{ + node_t *z = l->tail; + + n->next = (node_t *) &l->null; + n->prev = z; + z->next = n; + l->tail = n; +} + +/** + * add_head - prepend a node to a list + * \p l: linked list + * \p n: list node + * + * add_head() takes a node \p n and prepends it at the start of the list \p l. + */ +void +add_head(list_t *l, node_t *n) +{ + node_t *z = l->head; + + n->next = z; + n->prev = (node_t *) &l->head; + z->prev = n; + l->head = n; +} + +/** + * insert_node - insert a node to a list + * \p n: a new list node + * \p after: a node of a list + * + * Inserts a node \p n to a linked list after an already inserted + * node \p after. + */ +void +insert_node(node_t *n, node_t *after) +{ + node_t *z = after->next; + + n->next = z; + n->prev = after; + after->next = n; + z->prev = n; +} + +/** + * rem_node - remove a node from a list + * \p n: node to be removed + * + * Removes a node \p n from the list it's linked in. + */ +void +rem_node(node_t *n) +{ + node_t *z = n->prev; + node_t *x = n->next; + + z->next = x; + x->prev = z; + n->prev = 0; + n->next = 0; +} + +/** + * init_list - create an empty list + * \p l: list + * + * init_list() takes a &list structure and initializes its + * fields, so that it represents an empty list. + */ +void +init_list(list_t *l) +{ + l->head = (node_t *) &l->null; + l->null = NULL; + l->tail = (node_t *) &l->head; +} + +/** + * add_tail_list - concatenate two lists + * \p to: destination list + * \p l: source list + * + * This function appends all elements of the list \p l to + * the list \p to in constant time. + */ +void +add_tail_list(list_t *to, list_t *l) +{ + node_t *p = to->tail; + node_t *q = l->head; + + p->next = q; + q->prev = p; + q = l->tail; + q->next = (node_t *) &to->null; + to->tail = q; +} + +/** + * list_dup - duplicate list + * \p to: destination list + * \p l: source list + * + * This function duplicates all elements of the list \p l to + * the list \p to in linear time. + * + * This function only works with a homogenous item size. + */ +void list_dup(list_t *dst, list_t *src, size_t itemsz) +{ + node_t *n; + WALK_LIST(n, *src) { + node_t *i = malloc(itemsz); + memcpy(i, n, itemsz); + add_tail(dst, i); + } +} + +/** + * list_size - gets number of nodes + * \p l: list + * + * This function counts nodes in list \p l and returns this number. + */ +size_t list_size(const list_t *l) +{ + size_t count = 0; + + node_t *n; + WALK_LIST(n, *l) { + count++; + } + + return count; +} + +/** + * ptrlist_add - add pointer to pointer list + * \p to: destination list + * \p val: added pointer + * \p mm: memory context + */ +ptrnode_t *ptrlist_add(list_t *to, void *val, knot_mm_t *mm) +{ + ptrnode_t *node = mm_alloc(mm , sizeof(ptrnode_t)); + if (node == NULL) { + return NULL; + } else { + node->d = val; + } + add_tail(to, &node->n); + return node; +} + +/** + * ptrlist_free - free all nodes in pointer list + * \p list: list nodes + * \p mm: memory context + */ +void ptrlist_free(list_t *list, knot_mm_t *mm) +{ + node_t *n, *nxt; + WALK_LIST_DELSAFE(n, nxt, *list) { + mm_free(mm, n); + } + init_list(list); +} + +/** + * ptrlist_rem - remove pointer node + * \p val: pointer to remove + * \p mm: memory context + */ +void ptrlist_rem(ptrnode_t *node, knot_mm_t *mm) +{ + rem_node(&node->n); + mm_free(mm, node); +} + +/** + * ptrlist_deep_free - free all nodes incl referenced data + * \p list: list nodes + * \p mm: memory context + */ +void ptrlist_deep_free(list_t *l, knot_mm_t *mm) +{ + ptrnode_t *n; + WALK_LIST(n, *l) { + mm_free(mm, n->d); + } + ptrlist_free(l, mm); +} diff --git a/utils/kresctl/deps/ucw/lists.h b/utils/kresctl/deps/ucw/lists.h new file mode 100644 index 000000000..922e152f4 --- /dev/null +++ b/utils/kresctl/deps/ucw/lists.h @@ -0,0 +1,84 @@ +/* + * BIRD Library -- Linked Lists + * + * (c) 1998 Martin Mares + * (c) 2015, 2017 CZ.NIC, z.s.p.o. + * + * Can be freely distributed and used under the terms of the GNU GPL. + */ + +#pragma once + +/* + * I admit the list structure is very tricky and also somewhat awkward, + * but it's both efficient and easy to manipulate once one understands the + * basic trick: The list head always contains two synthetic nodes which are + * always present in the list: the head and the tail. But as the `next' + * entry of the tail and the `prev' entry of the head are both NULL, the + * nodes can overlap each other: + * + * head head_node.next + * null head_node.prev tail_node.next + * tail tail_node.prev + */ + +#include +#include "libknot/mm_ctx.h" + +typedef struct node { + struct node *next, *prev; +} node_t; + +typedef struct list { /* In fact two overlayed nodes */ + struct node *head, *null, *tail; +} list_t; + +#define NODE (node_t *) +#define HEAD(list) ((void *)((list).head)) +#define TAIL(list) ((void *)((list).tail)) +#define WALK_LIST(n,list) for(n=HEAD(list);(NODE (n))->next; \ + n=(void *)((NODE (n))->next)) +#define WALK_LIST_DELSAFE(n,nxt,list) \ + for(n=HEAD(list); (nxt=(void *)((NODE (n))->next)); n=(void *) nxt) +/* WALK_LIST_FIRST supposes that called code removes each processed node */ +#define WALK_LIST_FIRST(n,list) \ + while(n=HEAD(list), (NODE (n))->next) +#define WALK_LIST_BACKWARDS(n,list) for(n=TAIL(list);(NODE (n))->prev; \ + n=(void *)((NODE (n))->prev)) +#define WALK_LIST_BACKWARDS_DELSAFE(n,prv,list) \ + for(n=TAIL(list); prv=(void *)((NODE (n))->prev); n=(void *) prv) + +#define EMPTY_LIST(list) (!(list).head->next) + +/*! \brief Free every node in the list. */ +#define WALK_LIST_FREE(list) \ + do { \ + node_t *n=0,*nxt=0; \ + WALK_LIST_DELSAFE(n,nxt,list) { \ + free(n); \ + } \ + init_list(&list); \ + } while(0) + +void add_tail(list_t *, node_t *); +void add_head(list_t *, node_t *); +void rem_node(node_t *); +void add_tail_list(list_t *, list_t *); +void init_list(list_t *); +void insert_node(node_t *, node_t *); +void list_dup(list_t *dst, list_t *src, size_t itemsz); +size_t list_size(const list_t *); + +/*! + * \brief Generic pointer list implementation. + */ +typedef struct ptrnode { + node_t n; + void *d; +} ptrnode_t; + +ptrnode_t *ptrlist_add(list_t *, void *, knot_mm_t *); +void ptrlist_free(list_t *, knot_mm_t *); +void ptrlist_rem(ptrnode_t *node, knot_mm_t *mm); +void ptrlist_deep_free(list_t *, knot_mm_t *); + diff --git a/utils/kresctl/deps/ucw/mempool.c b/utils/kresctl/deps/ucw/mempool.c new file mode 100644 index 000000000..8e835c117 --- /dev/null +++ b/utils/kresctl/deps/ucw/mempool.c @@ -0,0 +1,323 @@ +/* + * UCW Library -- Memory Pools (One-Time Allocation) + * + * (c) 1997--2001 Martin Mares + * (c) 2007 Pavel Charvat + * (c) 2015, 2017 CZ.NIC, z.s.p.o. + * + * This software may be freely distributed and used according to the terms + * of the GNU Lesser General Public License. + */ + +#undef LOCAL_DEBUG + +#include +#include +#include +#include +#include +#include "contrib/asan.h" +#include "contrib/macros.h" +#include "contrib/ucw/mempool.h" + +/** \todo This shouldn't be precalculated, but computed on load. */ +#define CPU_PAGE_SIZE 4096 + +/** Align an integer \p s to the nearest higher multiple of \p a (which should be a power of two) **/ +#define ALIGN_TO(s, a) (((s)+a-1)&~(a-1)) +#define MP_CHUNK_TAIL ALIGN_TO(sizeof(struct mempool_chunk), CPU_STRUCT_ALIGN) +#define MP_SIZE_MAX (~0U - MP_CHUNK_TAIL - CPU_PAGE_SIZE) +#define DBG(s, ...) + +/** \note Imported MMAP backend from bigalloc.c */ +#define CONFIG_UCW_POOL_IS_MMAP +#ifdef CONFIG_UCW_POOL_IS_MMAP +#include +static void * +page_alloc(uint64_t len) +{ + if (!len) { + return NULL; + } + if (len > SIZE_MAX) { + return NULL; + } + assert(!(len & (CPU_PAGE_SIZE-1))); + uint8_t *p = mmap(NULL, len, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); + if (p == (uint8_t*) MAP_FAILED) { + return NULL; + } + return p; +} + +static void +page_free(void *start, uint64_t len) +{ + assert(!(len & (CPU_PAGE_SIZE-1))); + assert(!((uintptr_t) start & (CPU_PAGE_SIZE-1))); + munmap(start, len); +} +#endif + +struct mempool_chunk { + struct mempool_chunk *next; + unsigned size; +}; + +static unsigned +mp_align_size(unsigned size) +{ +#ifdef CONFIG_UCW_POOL_IS_MMAP + return ALIGN_TO(size + MP_CHUNK_TAIL, CPU_PAGE_SIZE) - MP_CHUNK_TAIL; +#else + return ALIGN_TO(size, CPU_STRUCT_ALIGN); +#endif +} + +void +mp_init(struct mempool *pool, unsigned chunk_size) +{ + chunk_size = mp_align_size(MAX(sizeof(struct mempool), chunk_size)); + *pool = (struct mempool) { + .chunk_size = chunk_size, + .threshold = chunk_size >> 1, + .last_big = &pool->last_big + }; +} + +static void * +mp_new_big_chunk(unsigned size) +{ + uint8_t *data = malloc(size + MP_CHUNK_TAIL); + if (!data) { + return NULL; + } + ASAN_POISON_MEMORY_REGION(data, size); + struct mempool_chunk *chunk = (struct mempool_chunk *)(data + size); + chunk->size = size; + return chunk; +} + +static void +mp_free_big_chunk(struct mempool_chunk *chunk) +{ + void *ptr = (uint8_t *)chunk - chunk->size; + ASAN_UNPOISON_MEMORY_REGION(ptr, chunk->size); + free(ptr); +} + +static void * +mp_new_chunk(unsigned size) +{ +#ifdef CONFIG_UCW_POOL_IS_MMAP + uint8_t *data = page_alloc(size + MP_CHUNK_TAIL); + if (!data) { + return NULL; + } + ASAN_POISON_MEMORY_REGION(data, size); + struct mempool_chunk *chunk = (struct mempool_chunk *)(data + size); + chunk->size = size; + return chunk; +#else + return mp_new_big_chunk(size); +#endif +} + +static void +mp_free_chunk(struct mempool_chunk *chunk) +{ +#ifdef CONFIG_UCW_POOL_IS_MMAP + uint8_t *data = (uint8_t *)chunk - chunk->size; + ASAN_UNPOISON_MEMORY_REGION(data, chunk->size); + page_free(data, chunk->size + MP_CHUNK_TAIL); +#else + mp_free_big_chunk(chunk); +#endif +} + +struct mempool * +mp_new(unsigned chunk_size) +{ + chunk_size = mp_align_size(MAX(sizeof(struct mempool), chunk_size)); + struct mempool_chunk *chunk = mp_new_chunk(chunk_size); + struct mempool *pool = (void *)chunk - chunk_size; + ASAN_UNPOISON_MEMORY_REGION(pool, sizeof(*pool)); + DBG("Creating mempool %p with %u bytes long chunks", pool, chunk_size); + chunk->next = NULL; + ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + *pool = (struct mempool) { + .state = { .free = { chunk_size - sizeof(*pool) }, .last = { chunk } }, + .chunk_size = chunk_size, + .threshold = chunk_size >> 1, + .last_big = &pool->last_big + }; + return pool; +} + +static void +mp_free_chain(struct mempool_chunk *chunk) +{ + while (chunk) { + ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + struct mempool_chunk *next = chunk->next; + mp_free_chunk(chunk); + chunk = next; + } +} + +static void +mp_free_big_chain(struct mempool_chunk *chunk) +{ + while (chunk) { + ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + struct mempool_chunk *next = chunk->next; + mp_free_big_chunk(chunk); + chunk = next; + } +} + +void +mp_delete(struct mempool *pool) +{ + if (pool == NULL) { + return; + } + DBG("Deleting mempool %p", pool); + mp_free_big_chain(pool->state.last[1]); + mp_free_chain(pool->unused); + mp_free_chain(pool->state.last[0]); // can contain the mempool structure +} + +void +mp_flush(struct mempool *pool) +{ + mp_free_big_chain(pool->state.last[1]); + struct mempool_chunk *chunk = pool->state.last[0], *next; + while (chunk) { + ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + if ((uint8_t *)chunk - chunk->size == (uint8_t *)pool) { + break; + } + next = chunk->next; + chunk->next = pool->unused; + ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + pool->unused = chunk; + chunk = next; + } + pool->state.last[0] = chunk; + if (chunk) { + pool->state.free[0] = chunk->size - sizeof(*pool); + ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + } else { + pool->state.free[0] = 0; + } + pool->state.last[1] = NULL; + pool->state.free[1] = 0; + pool->last_big = &pool->last_big; +} + +static void +mp_stats_chain(struct mempool_chunk *chunk, struct mempool_stats *stats, unsigned idx) +{ + struct mempool_chunk *next; + while (chunk) { + ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + stats->chain_size[idx] += chunk->size + sizeof(*chunk); + stats->chain_count[idx]++; + next = chunk->next; + ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + chunk = next; + } + stats->total_size += stats->chain_size[idx]; +} + +void +mp_stats(struct mempool *pool, struct mempool_stats *stats) +{ + bzero(stats, sizeof(*stats)); + mp_stats_chain(pool->state.last[0], stats, 0); + mp_stats_chain(pool->state.last[1], stats, 1); + mp_stats_chain(pool->unused, stats, 2); +} + +uint64_t +mp_total_size(struct mempool *pool) +{ + struct mempool_stats stats; + mp_stats(pool, &stats); + return stats.total_size; +} + +static void * +mp_alloc_internal(struct mempool *pool, unsigned size) +{ + struct mempool_chunk *chunk; + if (size <= pool->threshold) { + pool->idx = 0; + if (pool->unused) { + chunk = pool->unused; + ASAN_UNPOISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + pool->unused = chunk->next; + } else { + chunk = mp_new_chunk(pool->chunk_size); + } + chunk->next = pool->state.last[0]; + ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + pool->state.last[0] = chunk; + pool->state.free[0] = pool->chunk_size - size; + return (uint8_t *)chunk - pool->chunk_size; + } else if (size <= MP_SIZE_MAX) { + pool->idx = 1; + unsigned aligned = ALIGN_TO(size, CPU_STRUCT_ALIGN); + chunk = mp_new_big_chunk(aligned); + if (!chunk) { + return NULL; + } + chunk->next = pool->state.last[1]; + ASAN_POISON_MEMORY_REGION(chunk, sizeof(struct mempool_chunk)); + pool->state.last[1] = chunk; + pool->state.free[1] = aligned - size; + return pool->last_big = (uint8_t *)chunk - aligned; + } else { + fprintf(stderr, "Cannot allocate %u bytes from a mempool", size); + assert(0); + return NULL; + } +} + +void * +mp_alloc(struct mempool *pool, unsigned size) +{ + unsigned avail = pool->state.free[0] & ~(CPU_STRUCT_ALIGN - 1); + void *ptr = NULL; + if (size <= avail) { + pool->state.free[0] = avail - size; + ptr = (uint8_t*)pool->state.last[0] - avail; + } else { + ptr = mp_alloc_internal(pool, size); + } + ASAN_UNPOISON_MEMORY_REGION(ptr, size); + return ptr; +} + +void * +mp_alloc_noalign(struct mempool *pool, unsigned size) +{ + void *ptr = NULL; + if (size <= pool->state.free[0]) { + ptr = (uint8_t*)pool->state.last[0] - pool->state.free[0]; + pool->state.free[0] -= size; + } else { + ptr = mp_alloc_internal(pool, size); + } + ASAN_UNPOISON_MEMORY_REGION(ptr, size); + return ptr; +} + +void * +mp_alloc_zero(struct mempool *pool, unsigned size) +{ + void *ptr = mp_alloc(pool, size); + bzero(ptr, size); + return ptr; +} diff --git a/utils/kresctl/deps/ucw/mempool.h b/utils/kresctl/deps/ucw/mempool.h new file mode 100644 index 000000000..c5a4fa8ea --- /dev/null +++ b/utils/kresctl/deps/ucw/mempool.h @@ -0,0 +1,124 @@ +/* + * UCW Library -- Memory Pools + * + * (c) 1997--2005 Martin Mares + * (c) 2007 Pavel Charvat + * (c) 2015, 2017 CZ.NIC, z.s.p.o. + * + * This software may be freely distributed and used according to the terms + * of the GNU Lesser General Public License. + */ + +#pragma once + +#include +#include + +#define CPU_STRUCT_ALIGN (sizeof(void*)) + +/*** + * [[defs]] + * Definitions + * ----------- + ***/ + +/** + * Memory pool state (see mp_push(), ...). + * You should use this one as an opaque handle only, the insides are internal. + **/ +struct mempool_state { + unsigned free[2]; + void *last[2]; +}; + +/** + * Memory pool. + * You should use this one as an opaque handle only, the insides are internal. + **/ +struct mempool { + struct mempool_state state; + void *unused, *last_big; + unsigned chunk_size, threshold, idx; +}; + +struct mempool_stats { /** Mempool statistics. See mp_stats(). **/ + uint64_t total_size; /** Real allocated size in bytes. */ + unsigned chain_count[3]; /** Number of allocated chunks in small/big/unused chains. */ + unsigned chain_size[3]; /** Size of allocated chunks in small/big/unused chains. */ +}; + +/*** + * [[basic]] + * Basic manipulation + * ------------------ + ***/ + +/** + * Initialize a given mempool structure. + * \p chunk_size must be in the interval `[1, UINT_MAX / 2]`. + * It will allocate memory by this large chunks and take + * memory to satisfy requests from them. + * + * Memory pools can be treated as <>, see <>. + **/ +void mp_init(struct mempool *pool, unsigned chunk_size); + +/** + * Allocate and initialize a new memory pool. + * See \ref mp_init() for \p chunk_size limitations. + * + * The new mempool structure is allocated on the new mempool. + * + * Memory pools can be treated as <>, see <>. + **/ +struct mempool *mp_new(unsigned chunk_size); + +/** + * Cleanup mempool initialized by mp_init or mp_new. + * Frees all the memory allocated by this mempool and, + * if created by \ref mp_new(), the \p pool itself. + **/ +void mp_delete(struct mempool *pool); + +/** + * Frees all data on a memory pool, but leaves it working. + * It can keep some of the chunks allocated to serve + * further allocation requests. Leaves the \p pool alive, + * even if it was created with \ref mp_new(). + **/ +void mp_flush(struct mempool *pool); + +/** + * Compute some statistics for debug purposes. + * See the definition of the <>. + **/ +void mp_stats(struct mempool *pool, struct mempool_stats *stats); +uint64_t mp_total_size(struct mempool *pool); /** How many bytes were allocated by the pool. **/ + +/*** + * [[alloc]] + * Allocation routines + * ------------------- + ***/ + +/** + * The function allocates new \p size bytes on a given memory pool. + * If the \p size is zero, the resulting pointer is undefined, + * but it may be safely reallocated or used as the parameter + * to other functions below. + * + * The resulting pointer is always aligned to a multiple of + * `CPU_STRUCT_ALIGN` bytes and this condition remains true also + * after future reallocations. + **/ +void *mp_alloc(struct mempool *pool, unsigned size); + +/** + * The same as \ref mp_alloc(), but the result may be unaligned. + **/ +void *mp_alloc_noalign(struct mempool *pool, unsigned size); + +/** + * The same as \ref mp_alloc(), but fills the newly allocated memory with zeroes. + **/ +void *mp_alloc_zero(struct mempool *pool, unsigned size); diff --git a/utils/kresctl/interactive.c b/utils/kresctl/interactive.c new file mode 100644 index 000000000..fa53d31ac --- /dev/null +++ b/utils/kresctl/interactive.c @@ -0,0 +1,170 @@ +#include +#include +#include +#include +#include + +#include "lib/generic/array.h" +#include + +#include "commands.h" +#include "interactive.h" +#include "process.h" +#include "deps/lookup.h" + + +static void cmds_lookup(EditLine *el, const char *str, size_t str_len) +{ + lookup_t lookup; + int ret = lookup_init(&lookup); + if (ret != CLI_EOK) { + return; + } + + /* Fill the lookup with command names of static cmds. */ + for (const cmd_help_t *cmd_help = cmd_table; cmd_help->name != NULL; cmd_help++) { + ret = lookup_insert(&lookup, cmd_help->name, NULL); + if (ret != CLI_EOK) { + goto cmds_lookup_finish; + } + } + + /* Fill the lookup with command names of dynamic cmds. */ + dynarray_foreach(cmd_help, cmd_help_t *, i, dyn_cmd_help_table) { + cmd_help_t *cmd_help = *i; + ret = lookup_insert(&lookup, cmd_help->name, NULL); + if (ret != CLI_EOK) { + goto cmds_lookup_finish; + } + } + + lookup_complete(&lookup, str, str_len, el, true); + + cmds_lookup_finish: + lookup_deinit(&lookup); +} + +static unsigned char complete(EditLine *el, int ch) +{ + int argc, token, pos; + const char **argv; + + const LineInfo *li = el_line(el); + Tokenizer *tok = tok_init(NULL); + + /* Parse the line. */ + int ret = tok_line(tok, li, &argc, &argv, &token, &pos); + if (ret != 0) { + goto complete_exit; + } + + /* Show possible commands. */ + if (argc == 0) { + print_commands(NULL); + goto complete_exit; + } + + /* Complete the command name. */ + if (token == 0) { + cmds_lookup(el, argv[0], pos); + goto complete_exit; + } + + /* Find the command descriptor. */ + const cmd_desc_t *desc = cmd_table; + while (desc->name != NULL && strcmp(desc->name, argv[0]) != 0) { + desc++; + } + if (desc->name == NULL) { + goto complete_exit; + } + + complete_exit: + tok_reset(tok); + tok_end(tok); + + return CC_REDISPLAY; +} + +static char *prompt(EditLine *el) +{ + return PROGRAM_NAME"> "; +} + +int interactive_loop(params_t *process_params) +{ + char *hist_file = NULL; + const char *home = getenv("HOME"); + if (home != NULL) { + asprintf(&hist_file, "%s/"HISTORY_FILE, home); + } + if (hist_file == NULL) { + printf("failed to get home directory"); + } + + EditLine *el = el_init(PROGRAM_NAME, stdin, stdout, stderr); + if (el == NULL) { + printf("interactive mode not available"); + free(hist_file); + return 1; + } + + History *hist = history_init(); + if (hist == NULL) { + printf("interactive mode not available"); + el_end(el); + free(hist_file); + return 1; + } + + HistEvent hev = { 0 }; + history(hist, &hev, H_SETSIZE, 100); + el_set(el, EL_HIST, history, hist); + history(hist, &hev, H_LOAD, hist_file); + + el_set(el, EL_TERMINAL, NULL); + el_set(el, EL_EDITOR, "emacs"); + el_set(el, EL_PROMPT, prompt); + el_set(el, EL_SIGNAL, 1); + el_source(el, NULL); + + el_set(el, EL_ADDFN, PROGRAM_NAME"-complete", + "Perform "PROGRAM_NAME" completion.", complete); + el_set(el, EL_BIND, "^I", PROGRAM_NAME"-complete", NULL); + + int count; + const char *line; + while ((line = el_gets(el, &count)) != NULL && count > 0) { + history(hist, &hev, H_ENTER, line); + + Tokenizer *tok = tok_init(NULL); + + /* Tokenize the current line. */ + int argc; + const char **argv; + const LineInfo *li = el_line(el); + int ret = tok_line(tok, li, &argc, &argv, NULL, NULL); + if (ret != 0) { + continue; + } + + /* Process the command. */ + ret = process_cmd(argc, argv, process_params); + + history(hist, &hev, H_SAVE, hist_file); + tok_reset(tok); + tok_end(tok); + + /* Check for the exit command. */ + if (ret == CLI_EXIT) { + break; + } + } + + history_end(hist); + free(hist_file); + + el_end(el); + + return 0; +} \ No newline at end of file diff --git a/utils/kresctl/interactive.h b/utils/kresctl/interactive.h new file mode 100644 index 000000000..a91c69331 --- /dev/null +++ b/utils/kresctl/interactive.h @@ -0,0 +1,7 @@ +#pragma once + +#include + + +/** CLI interactive loop */ +int interactive_loop(params_t *params); \ No newline at end of file diff --git a/utils/kresctl/main.c b/utils/kresctl/main.c index d1f9a9643..7ba8bedfd 100644 --- a/utils/kresctl/main.c +++ b/utils/kresctl/main.c @@ -1,4 +1,123 @@ +#include +#include +#include + +#include "kresconfig.h" +#include "interactive.h" +#include "process.h" +#include "commands.h" + + +params_t params = { + .timeout = SYSREPO_TIMEOUT * 1000, +}; + +sysrepo_ctx_t *sysrepo_ctx; +sysrepo_ctx_t sysrepo_ctx_value = { + .connection = NULL, + .session = NULL +}; + +static void print_help(void) +{ + print_version(NULL); + + printf("\nUsage:\n" + " %s [parameters] [command-arguments]\n" + "\n" + "Parameters:\n" + " -t, --timeout "SPACE"Timeout for sysrepo operations.\n" + " "SPACE" (default %d seconds)\n" + " -h, --help "SPACE"Print the program help.\n" + " -V, --version "SPACE"Print the program version.\n", + PROGRAM_NAME, SYSREPO_TIMEOUT); + + print_commands(NULL); +} + int main(int argc, char *argv[]) { - return 0; + /* Long options. */ + struct option opts[] = { + { "timeout", required_argument, NULL, 't' }, + { "help", no_argument, NULL, 'h' }, + { "version", no_argument, NULL, 'V' }, + { NULL } + }; + + /* Init sysrepo connection */ + sysrepo_ctx = &sysrepo_ctx_value; + int ret = sr_connect(0, &sysrepo_ctx->connection); + if (ret){ + printf("[kresctl] failed to connect to sysrepo: %s\n", + sr_strerror(ret)); + goto sr_cleanup; + } + + /* Create dynamic commands table */ + ret = create_cmd_table(sysrepo_ctx->connection); + if (ret){ + printf("[kresctl] failed to create commands table\n"); + goto sr_cleanup; + } + + /* Parse command line parameters */ + int opt = 0; + while ((opt = getopt_long(argc, argv, "+t:hV", opts, NULL)) != -1) { + switch (opt) { + case 't': + params.timeout = atoi(optarg); + if (params.timeout < 1) { + printf("[kresctl] error '-t' requires a positive" + " number, not '%s'\n", optarg); + return EXIT_FAILURE; + } + /* Convert to milliseconds. */ + params.timeout *= 1000; + break; + case 'h': + print_help(); + return EXIT_SUCCESS; + case 'V': + print_version(NULL); + return EXIT_SUCCESS; + default: + print_help(); + return EXIT_FAILURE; + } + } + + if (argc - optind < 1) { + /* start interactive loop */ + ret = interactive_loop(¶ms); + } else { + /* + * Session with RUNNING datastore + * needs to be created here, because there + * is no interactive loop to create + * transaction with candidate datastore. + */ + ret = sr_session_start(sysrepo_ctx->connection, SR_DS_RUNNING, &sysrepo_ctx->session); + if (ret) { + printf("failed to start sysrepo session, %s\n", sr_strerror(ret)); + goto cleanup; + } + /* execute commands added from terminal */ + if (!ret) ret = process_cmd(argc - optind, (const char **)argv + optind, ¶ms); + + /* Stop sysrepo session. */ + ret = sr_session_stop(sysrepo_ctx->session); + if (ret) { + printf("failed to stop sysrepo session, %s\n", sr_strerror(ret)); + } + } + +cleanup: + /* free all dynamic tables */ + destroy_cmd_table(); + +sr_cleanup: + sr_disconnect(sysrepo_ctx->connection); + + return (ret == CLI_EOK) ? EXIT_SUCCESS : EXIT_FAILURE; } \ No newline at end of file diff --git a/utils/kresctl/meson.build b/utils/kresctl/meson.build index 583548df5..89e887bd3 100644 --- a/utils/kresctl/meson.build +++ b/utils/kresctl/meson.build @@ -1,20 +1,31 @@ kresctl_src = files([ 'main.c', + 'process.c', + 'interactive.c', + 'commands.c', + 'conf_file.c', + 'deps/lookup.c', + 'deps/mempattern.c', + 'deps/string.c', + 'deps/mempattern.c', ]) c_src_lint += kresctl_src -if build_sysrepo - kresc = executable( +message('--- kresctl dependencies ---') +libedit = dependency('libedit', required: false) + +if build_sysrepo and libedit.found() + kresctl = executable( 'kresctl', kresctl_src, sysrepo_common_src, dependencies: [ contrib_dep, libkres_dep, + libedit, libyang, libsysrepo, libknot, - luajit_inc, ], install: true, install_dir: get_option('sbindir'), diff --git a/utils/kresctl/process.c b/utils/kresctl/process.c new file mode 100644 index 000000000..530e69d02 --- /dev/null +++ b/utils/kresctl/process.c @@ -0,0 +1,66 @@ +#include +#include + +#include "contrib/dynarray.h" +#include "process.h" +#include "commands.h" + + +static const cmd_desc_t *get_cmd_desc(const char *command) +{ + /* Try to find requested command in buid-in commands. */ + const cmd_desc_t *desc = cmd_table; + while (desc->name != NULL) { + if (strcmp(desc->name, command) == 0) { + break; + } + desc++; + } + + /* Try to find requested command in created commands. */ + dynarray_foreach(cmd, cmd_desc_t *, i, dyn_cmd_table) { + cmd_desc_t *dyn_desc = *i; + if (strcmp(dyn_desc->name, command) == 0) { + desc = dyn_desc; + break; + } + } + + if (desc->name == NULL) { + printf("invalid command '%s'\n", command); + return NULL; + } + + return desc; +} + +int process_cmd(int argc, const char **argv, params_t *params) +{ + if (argc == 0) { + return ENOTSUP; + } + + /* Check the command name. */ + const cmd_desc_t *desc = get_cmd_desc(argv[0]); + if (desc == NULL) { + return ENOENT; + } + + /* Check for program exit. */ + if (desc->fcn == NULL) { + return CLI_EXIT; + } + + /* Prepare command arguments. */ + cmd_args_t args = { + .desc = desc, + .argc = argc - 1, + .argv = argv + 1, + .timeout = params->timeout, + }; + + /* Execute the command. */ + int ret = desc->fcn(&args); + + return ret; +} \ No newline at end of file diff --git a/utils/kresctl/process.h b/utils/kresctl/process.h new file mode 100644 index 000000000..b34110030 --- /dev/null +++ b/utils/kresctl/process.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include + +#include "kresconfig.h" +#include "process.h" +#include "commands.h" + +/* CLI globals */ +#define PROJECT_NAME "Knot Resolver" +#define PROGRAM_NAME "kresctl" +#define PROGRAM_DESC "control/administration tool" +#define HISTORY_FILE ".kresctl_history" +#define SPACE " " +/* default values */ +#define SYSREPO_TIMEOUT 10 +/* return codes */ +#define CLI_EOK 0 +#define CLI_ERR 1 +#define CLI_ECMD 2 +#define CLI_EXIT -10 + + +/* CLI parameters. */ +typedef struct { + int timeout; + int max_depth; +} params_t; + +/* CLI context */ +typedef struct { + sr_conn_ctx_t *connection; + sr_session_ctx_t *session; +} sysrepo_ctx_t; + +extern sysrepo_ctx_t *sysrepo_ctx; + +int process_cmd(int argc, const char **argv, params_t *params); diff --git a/utils/meson.build b/utils/meson.build index 66e63d339..7185f1e55 100644 --- a/utils/meson.build +++ b/utils/meson.build @@ -3,6 +3,6 @@ build_utils = get_option('utils') != 'disabled' subdir('kresctl') -subdir('kres_watcher') +subdir('watcher') subdir('client') subdir('cache_gc') diff --git a/utils/watcher/bindings/api.h b/utils/watcher/bindings/api.h new file mode 100644 index 000000000..1e7475147 --- /dev/null +++ b/utils/watcher/bindings/api.h @@ -0,0 +1,24 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include + +/** Make all the bindings accessible from the lua state, + * .i.e. define those lua tables. */ +void kr_bindings_register(lua_State *L); + diff --git a/utils/watcher/bindings/cache.c b/utils/watcher/bindings/cache.c new file mode 100644 index 000000000..68f1cfb9c --- /dev/null +++ b/utils/watcher/bindings/cache.c @@ -0,0 +1,467 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "bindings/impl.h" + +#include "worker.h" +#include "zimport.h" + +/** @internal return cache, or throw lua error if not open */ +struct kr_cache * cache_assert_open(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + struct kr_cache *cache = &engine->resolver.cache; + assert(cache); + if (!cache || !kr_cache_is_open(cache)) + lua_error_p(L, "no cache is open yet, use cache.open() or cache.size, etc."); + return cache; +} + +/** Return available cached backends. */ +static int cache_backends(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + + lua_newtable(L); + for (unsigned i = 0; i < engine->backends.len; ++i) { + const struct kr_cdb_api *api = engine->backends.at[i]; + lua_pushboolean(L, api == engine->resolver.cache.api); + lua_setfield(L, -2, api->name); + } + return 1; +} + +/** Return number of cached records. */ +static int cache_count(lua_State *L) +{ + struct kr_cache *cache = cache_assert_open(L); + + int count = cache->api->count(cache->db, &cache->stats); + if (count >= 0) { + /* First key is a version counter, omit it if nonempty. */ + lua_pushinteger(L, count ? count - 1 : 0); + return 1; + } + return 0; +} + +/** Return time of last checkpoint, or re-set it if passed `true`. */ +static int cache_checkpoint(lua_State *L) +{ + struct kr_cache *cache = cache_assert_open(L); + + if (lua_gettop(L) == 0) { /* Return the current value. */ + lua_newtable(L); + lua_pushnumber(L, cache->checkpoint_monotime); + lua_setfield(L, -2, "monotime"); + lua_newtable(L); + lua_pushnumber(L, cache->checkpoint_walltime.tv_sec); + lua_setfield(L, -2, "sec"); + lua_pushnumber(L, cache->checkpoint_walltime.tv_usec); + lua_setfield(L, -2, "usec"); + lua_setfield(L, -2, "walltime"); + return 1; + } + + if (lua_gettop(L) != 1 || !lua_isboolean(L, 1) || !lua_toboolean(L, 1)) + lua_error_p(L, "cache.checkpoint() takes no parameters or a true value"); + + kr_cache_make_checkpoint(cache); + return 1; +} + +/** Return cache statistics. */ +static int cache_stats(lua_State *L) +{ + struct kr_cache *cache = cache_assert_open(L); + lua_newtable(L); +#define add_stat(name) \ + lua_pushinteger(L, (cache->stats.name)); \ + lua_setfield(L, -2, #name) + add_stat(open); + add_stat(close); + add_stat(count); + add_stat(clear); + add_stat(commit); + add_stat(read); + add_stat(read_miss); + add_stat(write); + add_stat(remove); + add_stat(remove_miss); + add_stat(match); + add_stat(match_miss); + add_stat(read_leq); + add_stat(read_leq_miss); +#undef add_stat + + return 1; +} + +static const struct kr_cdb_api *cache_select(struct engine *engine, const char **conf) +{ + /* Return default backend */ + if (*conf == NULL || !strstr(*conf, "://")) { + return engine->backends.at[0]; + } + + /* Find storage backend from config prefix */ + for (unsigned i = 0; i < engine->backends.len; ++i) { + const struct kr_cdb_api *api = engine->backends.at[i]; + if (strncmp(*conf, api->name, strlen(api->name)) == 0) { + *conf += strlen(api->name) + strlen("://"); + return api; + } + } + + return NULL; +} + +static int cache_max_ttl(lua_State *L) +{ + struct kr_cache *cache = cache_assert_open(L); + + int n = lua_gettop(L); + if (n > 0) { + if (!lua_isnumber(L, 1) || n > 1) + lua_error_p(L, "expected 'max_ttl(number ttl)'"); + uint32_t min = cache->ttl_min; + int64_t ttl = lua_tointeger(L, 1); + if (ttl < 1 || ttl < min || ttl > UINT32_MAX) { + lua_error_p(L, + "max_ttl must be larger than minimum TTL, and in range <1, " + STR(UINT32_MAX) ">'"); + } + cache->ttl_max = ttl; + } + lua_pushinteger(L, cache->ttl_max); + return 1; +} + + +static int cache_min_ttl(lua_State *L) +{ + struct kr_cache *cache = cache_assert_open(L); + + int n = lua_gettop(L); + if (n > 0) { + if (!lua_isnumber(L, 1)) + lua_error_p(L, "expected 'min_ttl(number ttl)'"); + uint32_t max = cache->ttl_max; + int64_t ttl = lua_tointeger(L, 1); + if (ttl < 0 || ttl > max || ttl > UINT32_MAX) { + lua_error_p(L, + "min_ttl must be smaller than maximum TTL, and in range <0, " + STR(UINT32_MAX) ">'"); + } + cache->ttl_min = ttl; + } + lua_pushinteger(L, cache->ttl_min); + return 1; +} + +/** Open cache */ +static int cache_open(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n < 1 || !lua_isnumber(L, 1)) + lua_error_p(L, "expected 'open(number max_size, string config = \"\")'"); + + /* Select cache storage backend */ + struct engine *engine = engine_luaget(L); + + lua_Integer csize_lua = lua_tointeger(L, 1); + if (!(csize_lua >= 8192 && csize_lua < SIZE_MAX)) { /* min. is basically arbitrary */ + lua_error_p(L, "invalid cache size specified, it must be in range <8192, " + STR(SIZE_MAX) ">"); + } + size_t cache_size = csize_lua; + + const char *conf = n > 1 ? lua_tostring(L, 2) : NULL; + const char *uri = conf; + const struct kr_cdb_api *api = cache_select(engine, &conf); + if (!api) + lua_error_p(L, "unsupported cache backend"); + + /* Close if already open */ + kr_cache_close(&engine->resolver.cache); + + /* Reopen cache */ + struct kr_cdb_opts opts = { + (conf && strlen(conf)) ? conf : ".", + cache_size + }; + int ret = kr_cache_open(&engine->resolver.cache, api, &opts, engine->pool); + if (ret != 0) { + char cwd[PATH_MAX]; + get_workdir(cwd, sizeof(cwd)); + return luaL_error(L, "can't open cache path '%s'; working directory '%s'", opts.path, cwd); + } + + /* Store current configuration */ + lua_getglobal(L, "cache"); + lua_pushstring(L, "current_size"); + lua_pushnumber(L, cache_size); + lua_rawset(L, -3); + lua_pushstring(L, "current_storage"); + lua_pushstring(L, uri); + lua_rawset(L, -3); + lua_pop(L, 1); + + lua_pushboolean(L, 1); + return 1; +} + +static int cache_close(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + struct kr_cache *cache = &engine->resolver.cache; + if (!kr_cache_is_open(cache)) { + return 0; + } + + kr_cache_close(cache); + lua_getglobal(L, "cache"); + lua_pushstring(L, "current_size"); + lua_pushnumber(L, 0); + lua_rawset(L, -3); + lua_pop(L, 1); + lua_pushboolean(L, 1); + return 1; +} + +#if 0 +/** @internal Prefix walk. */ +static int cache_prefixed(struct kr_cache *cache, const char *prefix, bool exact_name, + knot_db_val_t keyval[][2], int maxcount) +{ + /* Convert to domain name */ + uint8_t buf[KNOT_DNAME_MAXLEN]; + if (!knot_dname_from_str(buf, prefix, sizeof(buf))) { + return kr_error(EINVAL); + } + /* Start prefix search */ + return kr_cache_match(cache, buf, exact_name, keyval, maxcount); +} +#endif + +/** Clear everything. */ +static int cache_clear_everything(lua_State *L) +{ + struct kr_cache *cache = cache_assert_open(L); + + /* Clear records and packets. */ + int ret = kr_cache_clear(cache); + lua_error_maybe(L, ret); + + /* Clear reputation tables */ + struct engine *engine = engine_luaget(L); + lru_reset(engine->resolver.cache_rtt); + lru_reset(engine->resolver.cache_rep); + lru_reset(engine->resolver.cache_cookie); + lua_pushboolean(L, true); + return 1; +} + +#if 0 +/** @internal Dump cache key into table on Lua stack. */ +static void cache_dump(lua_State *L, knot_db_val_t keyval[]) +{ + knot_dname_t dname[KNOT_DNAME_MAXLEN]; + char name[KNOT_DNAME_TXT_MAXLEN]; + uint16_t type; + + int ret = kr_unpack_cache_key(keyval[0], dname, &type); + if (ret < 0) { + return; + } + + ret = !knot_dname_to_str(name, dname, sizeof(name)); + assert(!ret); + if (ret) return; + + /* If name typemap doesn't exist yet, create it */ + lua_getfield(L, -1, name); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + lua_newtable(L); + } + /* Append to typemap */ + char type_buf[KR_RRTYPE_STR_MAXLEN] = { '\0' }; + knot_rrtype_to_string(type, type_buf, sizeof(type_buf)); + lua_pushboolean(L, true); + lua_setfield(L, -2, type_buf); + /* Set name typemap */ + lua_setfield(L, -2, name); +} + +/** Query cached records. TODO: fix caveats in ./README.rst documentation? */ +static int cache_get(lua_State *L) +{ + //struct kr_cache *cache = cache_assert_open(L); // to be fixed soon + + /* Check parameters */ + int n = lua_gettop(L); + if (n < 1 || !lua_isstring(L, 1)) + lua_error_p(L, "expected 'cache.get(string key)'"); + + /* Retrieve set of keys */ + const char *prefix = lua_tostring(L, 1); + knot_db_val_t keyval[100][2]; + int ret = cache_prefixed(cache, prefix, false/*FIXME*/, keyval, 100); + lua_error_maybe(L, ret); + /* Format output */ + lua_newtable(L); + for (int i = 0; i < ret; ++i) { + cache_dump(L, keyval[i]); + } + return 1; +} +#endif +static int cache_get(lua_State *L) +{ + lua_error_maybe(L, ENOSYS); + return kr_error(ENOSYS); /* doesn't happen */ +} + +/** Set time interval for cleaning rtt cache. + * Servers with score >= KR_NS_TIMEOUT will be cleaned after + * this interval ended up, so that they will be able to participate + * in NS elections again. */ +static int cache_ns_tout(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + struct kr_context *ctx = &engine->resolver; + + /* Check parameters */ + int n = lua_gettop(L); + if (n < 1) { + lua_pushinteger(L, ctx->cache_rtt_tout_retry_interval); + return 1; + } + + if (!lua_isnumber(L, 1)) + lua_error_p(L, "expected 'cache.ns_tout(interval in ms)'"); + + lua_Integer interval_lua = lua_tointeger(L, 1); + if (!(interval_lua > 0 && interval_lua < UINT_MAX)) { + lua_error_p(L, "invalid interval specified, it must be in range > 0, < " + STR(UINT_MAX)); + } + + ctx->cache_rtt_tout_retry_interval = interval_lua; + lua_pushinteger(L, ctx->cache_rtt_tout_retry_interval); + return 1; +} + +/** Zone import completion callback. + * Deallocates zone import context. */ +static void cache_zone_import_cb(int state, void *param) +{ + assert (param); + (void)state; + struct worker_ctx *worker = (struct worker_ctx *)param; + assert (worker->z_import); + zi_free(worker->z_import); + worker->z_import = NULL; +} + +/** Import zone from file. */ +static int cache_zone_import(lua_State *L) +{ + int ret = -1; + char msg[128]; + + struct worker_ctx *worker = the_worker; + if (!worker) { + strncpy(msg, "internal error, empty worker pointer", sizeof(msg)); + goto finish; + } + + if (worker->z_import && zi_import_started(worker->z_import)) { + strncpy(msg, "import already started", sizeof(msg)); + goto finish; + } + + (void)cache_assert_open(L); /* just check it in advance */ + + /* Check parameters */ + int n = lua_gettop(L); + if (n < 1 || !lua_isstring(L, 1)) { + strncpy(msg, "expected 'cache.zone_import(path to zone file)'", sizeof(msg)); + goto finish; + } + + /* Parse zone file */ + const char *zone_file = lua_tostring(L, 1); + + const char *default_origin = NULL; /* TODO */ + uint16_t default_rclass = 1; + uint32_t default_ttl = 0; + + if (worker->z_import == NULL) { + worker->z_import = zi_allocate(worker, cache_zone_import_cb, worker); + if (worker->z_import == NULL) { + strncpy(msg, "can't allocate zone import context", sizeof(msg)); + goto finish; + } + } + + ret = zi_zone_import(worker->z_import, zone_file, default_origin, + default_rclass, default_ttl); + + lua_newtable(L); + if (ret == 0) { + strncpy(msg, "zone file successfully parsed, import started", sizeof(msg)); + } else if (ret == 1) { + strncpy(msg, "TA not found", sizeof(msg)); + } else { + strncpy(msg, "error parsing zone file", sizeof(msg)); + } + +finish: + msg[sizeof(msg) - 1] = 0; + lua_newtable(L); + lua_pushstring(L, msg); + lua_setfield(L, -2, "msg"); + lua_pushnumber(L, ret); + lua_setfield(L, -2, "code"); + + return 1; +} + +int kr_bindings_cache(lua_State *L) +{ + static const luaL_Reg lib[] = { + { "backends", cache_backends }, + { "count", cache_count }, + { "stats", cache_stats }, + { "checkpoint", cache_checkpoint }, + { "open", cache_open }, + { "close", cache_close }, + { "clear_everything", cache_clear_everything }, + { "get", cache_get }, + { "max_ttl", cache_max_ttl }, + { "min_ttl", cache_min_ttl }, + { "ns_tout", cache_ns_tout }, + { "zone_import", cache_zone_import }, + { NULL, NULL } + }; + + luaL_register(L, "cache", lib); + return 1; +} + diff --git a/utils/watcher/bindings/event.c b/utils/watcher/bindings/event.c new file mode 100644 index 000000000..ffc9aefdc --- /dev/null +++ b/utils/watcher/bindings/event.c @@ -0,0 +1,225 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "bindings/impl.h" + +#include "worker.h" + +#include +#include + +static void event_free(uv_timer_t *timer) +{ + struct worker_ctx *worker = timer->loop->data; + lua_State *L = worker->engine->L; + int ref = (intptr_t) timer->data; + luaL_unref(L, LUA_REGISTRYINDEX, ref); + free(timer); +} + +static void event_callback(uv_timer_t *timer) +{ + struct worker_ctx *worker = timer->loop->data; + lua_State *L = worker->engine->L; + + /* Retrieve callback and execute */ + lua_rawgeti(L, LUA_REGISTRYINDEX, (intptr_t) timer->data); + lua_rawgeti(L, -1, 1); + lua_pushinteger(L, (intptr_t) timer->data); + int ret = execute_callback(L, 1); + /* Free callback if not recurrent or an error */ + if (ret != 0 || (uv_timer_get_repeat(timer) == 0 && uv_is_active((uv_handle_t *)timer) == 0)) { + if (!uv_is_closing((uv_handle_t *)timer)) { + uv_close((uv_handle_t *)timer, (uv_close_cb) event_free); + } + } +} + +static void event_fdcallback(uv_poll_t* handle, int status, int events) +{ + struct worker_ctx *worker = handle->loop->data; + lua_State *L = worker->engine->L; + + /* Retrieve callback and execute */ + lua_rawgeti(L, LUA_REGISTRYINDEX, (intptr_t) handle->data); + lua_rawgeti(L, -1, 1); + lua_pushinteger(L, (intptr_t) handle->data); + lua_pushinteger(L, status); + lua_pushinteger(L, events); + int ret = execute_callback(L, 3); + /* Free callback if not recurrent or an error */ + if (ret != 0) { + if (!uv_is_closing((uv_handle_t *)handle)) { + uv_close((uv_handle_t *)handle, (uv_close_cb) event_free); + } + } +} + +static int event_sched(lua_State *L, unsigned timeout, unsigned repeat) +{ + uv_timer_t *timer = malloc(sizeof(*timer)); + if (!timer) + lua_error_p(L, "out of memory"); + + /* Start timer with the reference */ + uv_loop_t *loop = uv_default_loop(); + uv_timer_init(loop, timer); + int ret = uv_timer_start(timer, event_callback, timeout, repeat); + if (ret != 0) { + free(timer); + lua_error_p(L, "couldn't start the event"); + } + + /* Save callback and timer in registry */ + lua_newtable(L); + lua_pushvalue(L, 2); + lua_rawseti(L, -2, 1); + lua_pushpointer(L, timer); + lua_rawseti(L, -2, 2); + int ref = luaL_ref(L, LUA_REGISTRYINDEX); + + /* Save reference to the timer */ + timer->data = (void *) (intptr_t)ref; + lua_pushinteger(L, ref); + return 1; +} + +static int event_after(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n < 2 || !lua_isnumber(L, 1) || !lua_isfunction(L, 2)) + lua_error_p(L, "expected 'after(number timeout, function)'"); + + return event_sched(L, lua_tointeger(L, 1), 0); +} + +static int event_recurrent(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n < 2 || !lua_isnumber(L, 1) || !lua_isfunction(L, 2)) + lua_error_p(L, "expected 'recurrent(number interval, function)'"); + + return event_sched(L, 0, lua_tointeger(L, 1)); +} + +static int event_cancel(lua_State *L) +{ + int n = lua_gettop(L); + if (n < 1 || !lua_isnumber(L, 1)) + lua_error_p(L, "expected 'cancel(number event)'"); + + /* Fetch event if it exists */ + lua_rawgeti(L, LUA_REGISTRYINDEX, lua_tointeger(L, 1)); + bool ok = lua_istable(L, -1); + + /* Close the timer */ + uv_handle_t **timer_pp = NULL; + if (ok) { + lua_rawgeti(L, -1, 2); + timer_pp = lua_touserdata(L, -1); + ok = timer_pp && *timer_pp; + /* That have been sufficient safety checks, hopefully. */ + } + if (ok && !uv_is_closing(*timer_pp)) { + uv_close(*timer_pp, (uv_close_cb)event_free); + } + lua_pushboolean(L, ok); + return 1; +} + +static int event_reschedule(lua_State *L) +{ + int n = lua_gettop(L); + if (n < 2 || !lua_isnumber(L, 1) || !lua_isnumber(L, 2)) + lua_error_p(L, "expected 'reschedule(number event, number timeout)'"); + + /* Fetch event if it exists */ + lua_rawgeti(L, LUA_REGISTRYINDEX, lua_tointeger(L, 1)); + bool ok = lua_istable(L, -1); + + /* Reschedule the timer */ + uv_handle_t **timer_pp = NULL; + if (ok) { + lua_rawgeti(L, -1, 2); + timer_pp = lua_touserdata(L, -1); + ok = timer_pp && *timer_pp; + /* That have been sufficient safety checks, hopefully. */ + } + if (ok && !uv_is_closing(*timer_pp)) { + int ret = uv_timer_start((uv_timer_t *)*timer_pp, + event_callback, lua_tointeger(L, 2), 0); + if (ret != 0) { + uv_close(*timer_pp, (uv_close_cb)event_free); + ok = false; + } + } + lua_pushboolean(L, ok); + return 1; +} + +static int event_fdwatch(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n < 2 || !lua_isnumber(L, 1) || !lua_isfunction(L, 2)) + lua_error_p(L, "expected 'socket(number fd, function)'"); + + uv_poll_t *handle = malloc(sizeof(*handle)); + if (!handle) + lua_error_p(L, "out of memory"); + + /* Start timer with the reference */ + int sock = lua_tointeger(L, 1); + uv_loop_t *loop = uv_default_loop(); + int ret = uv_poll_init(loop, handle, sock); + if (ret == 0) + ret = uv_poll_start(handle, UV_READABLE, event_fdcallback); + if (ret != 0) { + free(handle); + lua_error_p(L, "couldn't start event poller"); + } + + /* Save callback and timer in registry */ + lua_newtable(L); + lua_pushvalue(L, 2); + lua_rawseti(L, -2, 1); + lua_pushpointer(L, handle); + lua_rawseti(L, -2, 2); + int ref = luaL_ref(L, LUA_REGISTRYINDEX); + + /* Save reference to the timer */ + handle->data = (void *) (intptr_t)ref; + lua_pushinteger(L, ref); + return 1; +} + +int kr_bindings_event(lua_State *L) +{ + static const luaL_Reg lib[] = { + { "after", event_after }, + { "recurrent", event_recurrent }, + { "cancel", event_cancel }, + { "socket", event_fdwatch }, + { "reschedule", event_reschedule }, + { NULL, NULL } + }; + + luaL_register(L, "event", lib); + return 1; +} + diff --git a/utils/watcher/bindings/impl.c b/utils/watcher/bindings/impl.c new file mode 100644 index 000000000..624aece45 --- /dev/null +++ b/utils/watcher/bindings/impl.c @@ -0,0 +1,80 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include +#include +#include + + +const char * lua_table_checkindices(lua_State *L, const char *keys[]) +{ + /* Iterate over table at the top of the stack. + * http://www.lua.org/manual/5.1/manual.html#lua_next */ + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + lua_pop(L, 1); /* we don't need the value */ + /* We need to copy the key, as _tostring() confuses _next(). + * https://www.lua.org/manual/5.1/manual.html#lua_tolstring */ + lua_pushvalue(L, -1); + const char *key = lua_tostring(L, -1); + if (!key) + return ""; + for (const char **k = keys; ; ++k) { + if (*k == NULL) + return key; + if (strcmp(*k, key) == 0) + break; + } + } + return NULL; +} + + +/* Each of these just creates the correspondingly named lua table of functions. */ +int kr_bindings_cache (lua_State *L); /* ./cache.c */ +int kr_bindings_event (lua_State *L); /* ./event.c */ +int kr_bindings_modules (lua_State *L); /* ./modules.c */ +int kr_bindings_net (lua_State *L); /* ./net.c */ +int kr_bindings_worker (lua_State *L); /* ./worker.c */ + +void kr_bindings_register(lua_State *L) +{ + kr_bindings_cache(L); + kr_bindings_event(L); + kr_bindings_modules(L); + kr_bindings_net(L); + kr_bindings_worker(L); +} + +void lua_error_p(lua_State *L, const char *fmt, ...) +{ + /* Add a stack trace and throw the result as a lua error. */ + luaL_traceback(L, L, "error occured here (config filename:lineno is at the bottom, if config is involved):", 0); + /* Push formatted custom message, prepended with "ERROR: ". */ + lua_pushliteral(L, "\nERROR: "); + { + va_list args; + va_start(args, fmt); + lua_pushvfstring(L, fmt, args); + va_end(args); + } + lua_concat(L, 3); + lua_error(L); + /* TODO: we might construct a little more friendly trace by using luaL_where(). + * In particular, in case the error happens in a function that was called + * directly from a config file (the most common case), there isn't much need + * to format the trace in this heavy way. */ +} + diff --git a/utils/watcher/bindings/impl.h b/utils/watcher/bindings/impl.h new file mode 100644 index 000000000..180360b80 --- /dev/null +++ b/utils/watcher/bindings/impl.h @@ -0,0 +1,95 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include "engine.h" + +#include +#include + +/** Useful to stringify #defines into error strings. */ +#define STR(s) STRINGIFY_TOKEN(s) +#define STRINGIFY_TOKEN(s) #s + + +/** Check lua table at the top of the stack for allowed keys. + * \param keys NULL-terminated array of 0-terminated strings + * \return NULL if passed or the offending string (pushed on top of lua stack) + * \note Future work: if non-NULL is returned, there's extra stuff on the lua stack. + * \note Brute-force complexity: table length * summed length of keys. + */ +const char * lua_table_checkindices(lua_State *L, const char *keys[]); + +/** If the value at the top of the stack isn't a table, make it a single-element list. */ +static inline void lua_listify(lua_State *L) +{ + if (lua_istable(L, -1)) + return; + lua_createtable(L, 1, 0); + lua_insert(L, lua_gettop(L) - 1); /* swap the top two stack elements */ + lua_pushinteger(L, 1); + lua_insert(L, lua_gettop(L) - 1); /* swap the top two stack elements */ + lua_settable(L, -3); +} + + +/** Throw a formatted lua error. + * + * The message will get prefixed by "ERROR: " and supplemented by stack trace. + * \return never! It calls lua_error(). + * + * Example: + ERROR: not a valid pin_sha256: 'a1Z/3ek=', raw length 5 instead of 32 + stack traceback: + [C]: in function 'tls_client' + /PathToPREFIX/lib/kdns_modules/policy.lua:175: in function 'TLS_FORWARD' + /PathToConfig.lua:46: in main chunk + */ +KR_PRINTF(2) KR_NORETURN KR_COLD +void lua_error_p(lua_State *L, const char *fmt, ...); +/** @internal Annotate for static checkers. */ +KR_NORETURN int lua_error(lua_State *L); + +/** Shortcut for common case. */ +static inline void lua_error_maybe(lua_State *L, int err) +{ + if (err) lua_error_p(L, "%s", kr_strerror(err)); +} + +static inline int execute_callback(lua_State *L, int argc) +{ + int ret = engine_pcall(L, argc); + if (ret != 0) { + kr_log_error("error: %s\n", lua_tostring(L, -1)); + } + /* Clear the stack, there may be event a/o enything returned */ + lua_settop(L, 0); + return ret; +} + +/** Push a pointer as heavy/full userdata. + * + * It's useful as a replacement of lua_pushlightuserdata(), + * but note that it behaves differently in lua (converts to pointer-to-pointer). + */ +static inline void lua_pushpointer(lua_State *L, void *p) +{ + void *addr = lua_newuserdata(L, sizeof(void *)); + assert(addr); + memcpy(addr, &p, sizeof(void *)); +} + diff --git a/utils/watcher/bindings/modules.c b/utils/watcher/bindings/modules.c new file mode 100644 index 000000000..4502129ac --- /dev/null +++ b/utils/watcher/bindings/modules.c @@ -0,0 +1,91 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "bindings/impl.h" + + +/** List loaded modules */ +static int mod_list(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + lua_newtable(L); + for (unsigned i = 0; i < engine->modules.len; ++i) { + struct kr_module *module = engine->modules.at[i]; + lua_pushstring(L, module->name); + lua_rawseti(L, -2, i + 1); + } + return 1; +} + +/** Load module. */ +static int mod_load(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n != 1 || !lua_isstring(L, 1)) + lua_error_p(L, "expected 'load(string name)'"); + /* Parse precedence declaration */ + char *declaration = strdup(lua_tostring(L, 1)); + if (!declaration) + return kr_error(ENOMEM); + const char *name = strtok(declaration, " "); + const char *precedence = strtok(NULL, " "); + const char *ref = strtok(NULL, " "); + /* Load engine module */ + struct engine *engine = engine_luaget(L); + int ret = engine_register(engine, name, precedence, ref); + free(declaration); + if (ret != 0) { + if (ret == kr_error(EIDRM)) { + lua_error_p(L, "referenced module not found"); + } else { + lua_error_maybe(L, ret); + } + } + + lua_pushboolean(L, 1); + return 1; +} + +/** Unload module. */ +static int mod_unload(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n != 1 || !lua_isstring(L, 1)) + lua_error_p(L, "expected 'unload(string name)'"); + /* Unload engine module */ + struct engine *engine = engine_luaget(L); + int ret = engine_unregister(engine, lua_tostring(L, 1)); + lua_error_maybe(L, ret); + + lua_pushboolean(L, 1); + return 1; +} + +int kr_bindings_modules(lua_State *L) +{ + static const luaL_Reg lib[] = { + { "list", mod_list }, + { "load", mod_load }, + { "unload", mod_unload }, + { NULL, NULL } + }; + + luaL_register(L, "modules", lib); + return 1; +} + diff --git a/utils/watcher/bindings/net.c b/utils/watcher/bindings/net.c new file mode 100644 index 000000000..360ec2292 --- /dev/null +++ b/utils/watcher/bindings/net.c @@ -0,0 +1,1042 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "bindings/impl.h" + +#include "contrib/base64.h" +#include "network.h" +#include "tls.h" +#include "worker.h" + +#include + +/** Table and next index on top of stack -> append entries for given endpoint_array_t. */ +static int net_list_add(const char *key, void *val, void *ext) +{ + lua_State *L = (lua_State *)ext; + lua_Integer i = lua_tointeger(L, -1); + endpoint_array_t *ep_array = val; + for (int j = 0; j < ep_array->len; ++j) { + struct endpoint *ep = &ep_array->at[j]; + lua_newtable(L); // connection tuple + + if (ep->flags.kind) { + lua_pushstring(L, ep->flags.kind); + } else if (ep->flags.tls) { + lua_pushliteral(L, "tls"); + } else { + lua_pushliteral(L, "dns"); + } + lua_setfield(L, -2, "kind"); + + lua_newtable(L); // "transport" table + + lua_pushboolean(L, ep->flags.freebind); + lua_setfield(L, -2, "freebind"); + + switch (ep->family) { + case AF_INET: + lua_pushliteral(L, "inet4"); + break; + case AF_INET6: + lua_pushliteral(L, "inet6"); + break; + case AF_UNIX: + lua_pushliteral(L, "unix"); + break; + default: + lua_pushliteral(L, "invalid"); + assert(!EINVAL); + } + lua_setfield(L, -2, "family"); + + lua_pushstring(L, key); + if (ep->family != AF_UNIX) { + lua_setfield(L, -2, "ip"); + } else { + lua_setfield(L, -2, "path"); + } + + if (ep->family != AF_UNIX) { + lua_pushinteger(L, ep->port); + lua_setfield(L, -2, "port"); + } + + if (ep->family == AF_UNIX) { + lua_pushliteral(L, "stream"); + } else if (ep->flags.sock_type == SOCK_STREAM) { + lua_pushliteral(L, "tcp"); + } else if (ep->flags.sock_type == SOCK_DGRAM) { + lua_pushliteral(L, "udp"); + } else { + assert(!EINVAL); + lua_pushliteral(L, "invalid"); + } + lua_setfield(L, -2, "protocol"); + + lua_setfield(L, -2, "transport"); + + lua_settable(L, -3); + i++; + lua_pushinteger(L, i); + } + return kr_ok(); +} + +/** List active endpoints. */ +static int net_list(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + lua_newtable(L); + lua_pushinteger(L, 1); + map_walk(&engine->net.endpoints, net_list_add, L); + lua_pop(L, 1); + return 1; +} + +/** Listen on an address list represented by the top of lua stack. + * \note kind ownership is not transferred + * \return success */ +static bool net_listen_addrs(lua_State *L, int port, bool tls, const char *kind, bool freebind) +{ + /* Case: table with 'addr' field; only follow that field directly. */ + lua_getfield(L, -1, "addr"); + if (!lua_isnil(L, -1)) { + lua_replace(L, -2); + } else { + lua_pop(L, 1); + } + + /* Case: string, representing a single address. */ + const char *str = lua_tostring(L, -1); + if (str != NULL) { + struct engine *engine = engine_luaget(L); + int ret = 0; + endpoint_flags_t flags = { .tls = tls, .freebind = freebind }; + if (!kind && !flags.tls) { /* normal UDP */ + flags.sock_type = SOCK_DGRAM; + ret = network_listen(&engine->net, str, port, flags); + } + if (!kind && ret == 0) { /* common for normal TCP and TLS */ + flags.sock_type = SOCK_STREAM; + ret = network_listen(&engine->net, str, port, flags); + } + if (kind) { + flags.kind = strdup(kind); + flags.sock_type = SOCK_STREAM; /* TODO: allow to override this? */ + ret = network_listen(&engine->net, str, port, flags); + } + if (ret != 0) { + const char *stype = flags.sock_type == SOCK_DGRAM ? "UDP" : "TCP"; + kr_log_error("[system] bind to '%s@%d' (%s): %s\n", + str, port, stype, kr_strerror(ret)); + } + return ret == 0; + } + + /* Last case: table where all entries are added recursively. */ + if (!lua_istable(L, -1)) + lua_error_p(L, "bad type for address"); + lua_pushnil(L); + while (lua_next(L, -2)) { + if (!net_listen_addrs(L, port, tls, kind, freebind)) + return false; + lua_pop(L, 1); + } + return true; +} + +static bool table_get_flag(lua_State *L, int index, const char *key, bool def) +{ + bool result = def; + lua_getfield(L, index, key); + if (lua_isboolean(L, -1)) { + result = lua_toboolean(L, -1); + } + lua_pop(L, 1); + return result; +} + +/** Listen on endpoint. */ +static int net_listen(lua_State *L) +{ + /* Check parameters */ + int n = lua_gettop(L); + if (n < 1 || n > 3) { + lua_error_p(L, "expected one to three arguments; usage:\n" + "net.listen(addressses, [port = " STR(KR_DNS_PORT) + ", flags = {tls = (port == " STR(KR_DNS_TLS_PORT) ")}])\n"); + } + + int port = KR_DNS_PORT; + if (n > 1) { + if (lua_isnumber(L, 2)) { + port = lua_tointeger(L, 2); + } else + if (!lua_isnil(L, 2)) { + lua_error_p(L, "wrong type of second parameter (port number)"); + } + } + + bool tls = (port == KR_DNS_TLS_PORT); + bool freebind = false; + const char *kind = NULL; + if (n > 2 && !lua_isnil(L, 3)) { + if (!lua_istable(L, 3)) + lua_error_p(L, "wrong type of third parameter (table expected)"); + tls = table_get_flag(L, 3, "tls", tls); + freebind = table_get_flag(L, 3, "freebind", tls); + + lua_getfield(L, 3, "kind"); + const char *k = lua_tostring(L, -1); + if (k && strcasecmp(k, "dns") == 0) { + tls = false; + } else + if (k && strcasecmp(k, "tls") == 0) { + tls = true; + } else + if (k) { + kind = k; + } + } + + /* Memory management of `kind` string is difficult due to longjmp etc. + * Pop will unreference the lua value, so we store it on C stack instead (!) */ + const int kind_alen = kind ? strlen(kind) + 1 : 1 /* 0 length isn't C standard */; + char kind_buf[kind_alen]; + if (kind) { + memcpy(kind_buf, kind, kind_alen); + kind = kind_buf; + } + + /* Now focus on the first argument. */ + lua_settop(L, 1); + if (!net_listen_addrs(L, port, tls, kind, freebind)) + lua_error_p(L, "net.listen() failed to bind"); + lua_pushboolean(L, true); + return 1; +} + +/** Close endpoint. */ +static int net_close(lua_State *L) +{ + /* Check parameters */ + const int n = lua_gettop(L); + bool ok = (n == 1 || n == 2) && lua_isstring(L, 1); + const char *addr = lua_tostring(L, 1); + int port; + if (ok && (n < 2 || lua_isnil(L, 2))) { + port = -1; + } else if (ok) { + ok = lua_isnumber(L, 2); + port = lua_tointeger(L, 2); + ok = ok && port >= 0 && port <= 65535; + } + if (!ok) + lua_error_p(L, "expected 'close(string addr, [number port])'"); + + struct network *net = &engine_luaget(L)->net; + int ret = network_close(net, addr, port); + lua_pushboolean(L, ret == 0); + return 1; +} + +/** List available interfaces. */ +static int net_interfaces(lua_State *L) +{ + /* Retrieve interface list */ + int count = 0; + char buf[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */ + uv_interface_address_t *info = NULL; + uv_interface_addresses(&info, &count); + lua_newtable(L); + for (int i = 0; i < count; ++i) { + uv_interface_address_t iface = info[i]; + lua_getfield(L, -1, iface.name); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + lua_newtable(L); + } + + /* Address */ + lua_getfield(L, -1, "addr"); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + lua_newtable(L); + } + if (iface.address.address4.sin_family == AF_INET) { + uv_ip4_name(&iface.address.address4, buf, sizeof(buf)); + } else if (iface.address.address4.sin_family == AF_INET6) { + uv_ip6_name(&iface.address.address6, buf, sizeof(buf)); + } else { + buf[0] = '\0'; + } + lua_pushstring(L, buf); + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + lua_setfield(L, -2, "addr"); + + /* Hardware address. */ + char *p = buf; + for (int k = 0; k < sizeof(iface.phys_addr); ++k) { + sprintf(p, "%.2x:", (uint8_t)iface.phys_addr[k]); + p += 3; + } + p[-1] = '\0'; + lua_pushstring(L, buf); + lua_setfield(L, -2, "mac"); + + /* Push table */ + lua_setfield(L, -2, iface.name); + } + uv_free_interface_addresses(info, count); + + return 1; +} + +/** Set UDP maximum payload size. */ +static int net_bufsize(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + knot_rrset_t *opt_rr = engine->resolver.opt_rr; + if (!lua_isnumber(L, 1)) { + lua_pushinteger(L, knot_edns_get_payload(opt_rr)); + return 1; + } + int bufsize = lua_tointeger(L, 1); + if (bufsize < 512 || bufsize > UINT16_MAX) + lua_error_p(L, "bufsize must be within <512, " STR(UINT16_MAX) ">"); + knot_edns_set_payload(opt_rr, (uint16_t) bufsize); + return 0; +} + +/** Set TCP pipelining size. */ +static int net_pipeline(lua_State *L) +{ + struct worker_ctx *worker = the_worker; + if (!worker) { + return 0; + } + if (!lua_isnumber(L, 1)) { + lua_pushinteger(L, worker->tcp_pipeline_max); + return 1; + } + int len = lua_tointeger(L, 1); + if (len < 0 || len > UINT16_MAX) + lua_error_p(L, "tcp_pipeline must be within <0, " STR(UINT16_MAX) ">"); + worker->tcp_pipeline_max = len; + lua_pushinteger(L, len); + return 1; +} + +static int net_tls(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + if (!engine) { + return 0; + } + struct network *net = &engine->net; + if (!net) { + return 0; + } + + /* Only return current credentials. */ + if (lua_gettop(L) == 0) { + /* No credentials configured yet. */ + if (!net->tls_credentials) { + return 0; + } + lua_newtable(L); + lua_pushstring(L, net->tls_credentials->tls_cert); + lua_setfield(L, -2, "cert_file"); + lua_pushstring(L, net->tls_credentials->tls_key); + lua_setfield(L, -2, "key_file"); + return 1; + } + + if ((lua_gettop(L) != 2) || !lua_isstring(L, 1) || !lua_isstring(L, 2)) + lua_error_p(L, "net.tls takes two parameters: (\"cert_file\", \"key_file\")"); + + int r = tls_certificate_set(net, lua_tostring(L, 1), lua_tostring(L, 2)); + lua_error_maybe(L, r); + + lua_pushboolean(L, true); + return 1; +} + +/** Return a lua table with TLS authentication parameters. + * The format is the same as passed to policy.TLS_FORWARD(); + * more precisely, it's in a compatible canonical form. */ +static int tls_params2lua(lua_State *L, trie_t *params) +{ + lua_newtable(L); + if (!params) /* Allowed special case. */ + return 1; + trie_it_t *it; + size_t list_index = 0; + for (it = trie_it_begin(params); !trie_it_finished(it); trie_it_next(it)) { + /* Prepare table for the current address + * and its index in the returned list. */ + lua_pushinteger(L, ++list_index); + lua_createtable(L, 0, 2); + + /* Get the "addr#port" string... */ + size_t ia_len; + const char *key = trie_it_key(it, &ia_len); + int af = AF_UNSPEC; + if (ia_len == 2 + sizeof(struct in_addr)) { + af = AF_INET; + } else if (ia_len == 2 + sizeof(struct in6_addr)) { + af = AF_INET6; + } + if (!key || af == AF_UNSPEC) { + assert(false); + lua_error_p(L, "internal error: bad IP address"); + } + uint16_t port; + memcpy(&port, key, sizeof(port)); + port = ntohs(port); + const char *ia = key + sizeof(port); + char str[INET6_ADDRSTRLEN + 1 + 5 + 1]; + size_t len = sizeof(str); + if (kr_ntop_str(af, ia, port, str, &len) != kr_ok()) { + assert(false); + lua_error_p(L, "internal error: bad IP address conversion"); + } + /* ...and push it as [1]. */ + lua_pushinteger(L, 1); + lua_pushlstring(L, str, len - 1 /* len includes '\0' */); + lua_settable(L, -3); + + const tls_client_param_t *e = *trie_it_val(it); + if (!e) + lua_error_p(L, "internal problem - NULL entry for %s", str); + + /* .hostname = */ + if (e->hostname) { + lua_pushstring(L, e->hostname); + lua_setfield(L, -2, "hostname"); + } + + /* .ca_files = */ + if (e->ca_files.len) { + lua_createtable(L, e->ca_files.len, 0); + for (size_t i = 0; i < e->ca_files.len; ++i) { + lua_pushinteger(L, i + 1); + lua_pushstring(L, e->ca_files.at[i]); + lua_settable(L, -3); + } + lua_setfield(L, -2, "ca_files"); + } + + /* .pin_sha256 = ... ; keep sane indentation via goto. */ + if (!e->pins.len) goto no_pins; + lua_createtable(L, e->pins.len, 0); + for (size_t i = 0; i < e->pins.len; ++i) { + uint8_t pin_base64[TLS_SHA256_BASE64_BUFLEN]; + int err = base64_encode(e->pins.at[i], TLS_SHA256_RAW_LEN, + pin_base64, sizeof(pin_base64)); + if (err < 0) { + assert(false); + lua_error_p(L, + "internal problem when converting pin_sha256: %s", + kr_strerror(err)); + } + lua_pushinteger(L, i + 1); + lua_pushlstring(L, (const char *)pin_base64, err); + /* pin_base64 isn't 0-terminated ^^^ */ + lua_settable(L, -3); + } + lua_setfield(L, -2, "pin_sha256"); + + no_pins:/* .insecure = */ + if (e->insecure) { + lua_pushboolean(L, true); + lua_setfield(L, -2, "insecure"); + } + /* Now the whole table is pushed atop the returned list. */ + lua_settable(L, -3); + } + trie_it_free(it); + return 1; +} + +static inline int cmp_sha256(const void *p1, const void *p2) +{ + return memcmp(*(char * const *)p1, *(char * const *)p2, TLS_SHA256_RAW_LEN); +} +static int net_tls_client(lua_State *L) +{ + /* TODO idea: allow starting the lua table with *multiple* IP targets, + * meaning the authentication config should be applied to each. + */ + struct network *net = &engine_luaget(L)->net; + if (lua_gettop(L) == 0) + return tls_params2lua(L, net->tls_client_params); + /* Various basic sanity-checking. */ + if (lua_gettop(L) != 1 || !lua_istable(L, 1)) + lua_error_maybe(L, EINVAL); + /* check that only allowed keys are present */ + { + const char *bad_key = lua_table_checkindices(L, (const char *[]) + { "1", "hostname", "ca_file", "pin_sha256", "insecure", NULL }); + if (bad_key) + lua_error_p(L, "found unexpected key '%s'", bad_key); + } + + /**** Phase 1: get the parameter into a C struct, incl. parse of CA files, + * regardless of the address-pair having an entry already. */ + + tls_client_param_t *newcfg = tls_client_param_new(); + if (!newcfg) + lua_error_p(L, "out of memory or something like that :-/"); + /* Shortcut for cleanup actions needed from now on. */ + #define ERROR(...) do { \ + free(newcfg); \ + lua_error_p(L, __VA_ARGS__); \ + } while (false) + + /* .hostname - always accepted. */ + lua_getfield(L, 1, "hostname"); + if (!lua_isnil(L, -1)) { + const char *hn_str = lua_tostring(L, -1); + /* Convert to lower-case dname and back, for checking etc. */ + knot_dname_t dname[KNOT_DNAME_MAXLEN]; + if (!hn_str || !knot_dname_from_str(dname, hn_str, sizeof(dname))) + ERROR("invalid hostname"); + knot_dname_to_lower(dname); + char *h = knot_dname_to_str_alloc(dname); + if (!h) + ERROR("%s", kr_strerror(ENOMEM)); + /* Strip the final dot produced by knot_dname_*() */ + h[strlen(h) - 1] = '\0'; + newcfg->hostname = h; + } + lua_pop(L, 1); + + /* .ca_file - it can be a list of paths, contrary to the name. */ + bool has_ca_file = false; + lua_getfield(L, 1, "ca_file"); + if (!lua_isnil(L, -1)) { + if (!newcfg->hostname) + ERROR("missing hostname but specifying ca_file"); + lua_listify(L); + array_init(newcfg->ca_files); /*< placate apparently confused scan-build */ + if (array_reserve(newcfg->ca_files, lua_objlen(L, -1)) != 0) /*< optim. */ + ERROR("%s", kr_strerror(ENOMEM)); + /* Iterate over table at the top of the stack. + * http://www.lua.org/manual/5.1/manual.html#lua_next */ + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + has_ca_file = true; /* deferred here so that {} -> false */ + const char *ca_file = lua_tostring(L, -1); + if (!ca_file) + ERROR("ca_file contains a non-string"); + /* Let gnutls process it immediately, so garbage gets detected. */ + int ret = gnutls_certificate_set_x509_trust_file( + newcfg->credentials, ca_file, GNUTLS_X509_FMT_PEM); + if (ret < 0) { + ERROR("failed to import certificate file '%s': %s - %s\n", + ca_file, gnutls_strerror_name(ret), + gnutls_strerror(ret)); + } else { + kr_log_verbose( + "[tls_client] imported %d certs from file '%s'\n", + ret, ca_file); + } + + ca_file = strdup(ca_file); + if (!ca_file || array_push(newcfg->ca_files, ca_file) < 0) + ERROR("%s", kr_strerror(ENOMEM)); + } + /* Sort the strings for easier comparison later. */ + if (newcfg->ca_files.len) { + qsort(&newcfg->ca_files.at[0], newcfg->ca_files.len, + sizeof(newcfg->ca_files.at[0]), strcmp_p); + } + } + lua_pop(L, 1); + + /* .pin_sha256 */ + lua_getfield(L, 1, "pin_sha256"); + if (!lua_isnil(L, -1)) { + if (has_ca_file) + ERROR("mixing pin_sha256 with ca_file is not supported"); + lua_listify(L); + array_init(newcfg->pins); /*< placate apparently confused scan-build */ + if (array_reserve(newcfg->pins, lua_objlen(L, -1)) != 0) /*< optim. */ + ERROR("%s", kr_strerror(ENOMEM)); + /* Iterate over table at the top of the stack. */ + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + const char *pin = lua_tostring(L, -1); + if (!pin) + ERROR("pin_sha256 is not a string"); + uint8_t *pin_raw = malloc(TLS_SHA256_RAW_LEN); + /* Push the string early to simplify error processing. */ + if (!pin_raw || array_push(newcfg->pins, pin_raw) < 0) { + assert(false); + free(pin_raw); + ERROR("%s", kr_strerror(ENOMEM)); + } + int ret = base64_decode((const uint8_t *)pin, strlen(pin), + pin_raw, TLS_SHA256_RAW_LEN + 8); + if (ret < 0) { + ERROR("not a valid pin_sha256: '%s' (length %d), %s\n", + pin, (int)strlen(pin), knot_strerror(ret)); + } else if (ret != TLS_SHA256_RAW_LEN) { + ERROR("not a valid pin_sha256: '%s', " + "raw length %d instead of " + STR(TLS_SHA256_RAW_LEN)"\n", + pin, ret); + } + } + /* Sort the raw strings for easier comparison later. */ + if (newcfg->pins.len) { + qsort(&newcfg->pins.at[0], newcfg->pins.len, + sizeof(newcfg->pins.at[0]), cmp_sha256); + } + } + lua_pop(L, 1); + + /* .insecure */ + lua_getfield(L, 1, "insecure"); + if (lua_isnil(L, -1)) { + if (!newcfg->hostname && !newcfg->pins.len) + ERROR("no way to authenticate and not set as insecure"); + } else if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) { + newcfg->insecure = true; + if (has_ca_file || newcfg->pins.len) + ERROR("set as insecure but provided authentication config"); + } else { + ERROR("incorrect value in the 'insecure' field"); + } + lua_pop(L, 1); + + /* Init CAs from system trust store, if needed. */ + if (!newcfg->insecure && !newcfg->pins.len && !has_ca_file) { + int ret = gnutls_certificate_set_x509_system_trust(newcfg->credentials); + if (ret <= 0) { + ERROR("failed to use system CA certificate store: %s", + ret ? gnutls_strerror(ret) : kr_strerror(ENOENT)); + } else { + kr_log_verbose( + "[tls_client] imported %d certs from system store\n", + ret); + } + } + #undef ERROR + + /**** Phase 2: deal with the C authentication "table". */ + /* Parse address and port. */ + lua_pushinteger(L, 1); + lua_gettable(L, 1); + const char *addr_str = lua_tostring(L, -1); + if (!addr_str) + lua_error_p(L, "address is not a string"); + char buf[INET6_ADDRSTRLEN + 1]; + uint16_t port = 853; + const struct sockaddr *addr = NULL; + if (kr_straddr_split(addr_str, buf, &port) == kr_ok()) + addr = kr_straddr_socket(buf, port, NULL); + /* Add newcfg into the C map, saving the original into oldcfg. */ + if (!addr) + lua_error_p(L, "address '%s' could not be converted", addr_str); + tls_client_param_t **oldcfgp = tls_client_param_getptr( + &net->tls_client_params, addr, true); + free_const(addr); + if (!oldcfgp) + lua_error_p(L, "internal error when extending tls_client_params map"); + tls_client_param_t *oldcfg = *oldcfgp; + *oldcfgp = newcfg; /* replace old config in trie with the new one */ + /* If there was no original entry, it's easy! */ + if (!oldcfg) + return 0; + + /* Check for equality (newcfg vs. oldcfg), and print a warning if not equal.*/ + const bool ok_h = (!newcfg->hostname && !oldcfg->hostname) + || (newcfg->hostname && oldcfg->hostname && strcmp(newcfg->hostname, oldcfg->hostname) == 0); + bool ok_ca = newcfg->ca_files.len == oldcfg->ca_files.len; + for (int i = 0; ok_ca && i < newcfg->ca_files.len; ++i) + ok_ca = strcmp(newcfg->ca_files.at[i], oldcfg->ca_files.at[i]) == 0; + bool ok_pins = newcfg->pins.len == oldcfg->pins.len; + for (int i = 0; ok_pins && i < newcfg->pins.len; ++i) + ok_ca = memcmp(newcfg->pins.at[i], oldcfg->pins.at[i], TLS_SHA256_RAW_LEN) == 0; + const bool ok_insecure = newcfg->insecure == oldcfg->insecure; + if (!(ok_h && ok_ca && ok_pins && ok_insecure)) { + kr_log_info("[tls_client] " + "warning: re-defining TLS authentication parameters for %s\n", + addr_str); + } + tls_client_param_unref(oldcfg); + return 0; +} + +int net_tls_client_clear(lua_State *L) +{ + /* One parameter: address -> convert it to a struct sockaddr. */ + if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) + lua_error_p(L, "net.tls_client_clear() requires one parameter (\"address\")"); + const char *addr_str = lua_tostring(L, 1); + char buf[INET6_ADDRSTRLEN + 1]; + uint16_t port = 853; + const struct sockaddr *addr = NULL; + if (kr_straddr_split(addr_str, buf, &port) == kr_ok()) + addr = kr_straddr_socket(buf, port, NULL); + if (!addr) + lua_error_p(L, "invalid IP address"); + /* Do the actual removal. */ + struct network *net = &engine_luaget(L)->net; + int r = tls_client_param_remove(net->tls_client_params, addr); + free_const(addr); + lua_error_maybe(L, r); + lua_pushboolean(L, true); + return 1; +} + +static int net_tls_padding(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + + /* Only return current padding. */ + if (lua_gettop(L) == 0) { + if (engine->resolver.tls_padding < 0) { + lua_pushboolean(L, true); + return 1; + } else if (engine->resolver.tls_padding == 0) { + lua_pushboolean(L, false); + return 1; + } + lua_pushinteger(L, engine->resolver.tls_padding); + return 1; + } + + const char *errstr = "net.tls_padding parameter has to be true, false," + " or a number between <0, " STR(MAX_TLS_PADDING) ">"; + if (lua_gettop(L) != 1) + lua_error_p(L, "%s", errstr); + if (lua_isboolean(L, 1)) { + bool x = lua_toboolean(L, 1); + if (x) { + engine->resolver.tls_padding = -1; + } else { + engine->resolver.tls_padding = 0; + } + } else if (lua_isnumber(L, 1)) { + int padding = lua_tointeger(L, 1); + if ((padding < 0) || (padding > MAX_TLS_PADDING)) + lua_error_p(L, "%s", errstr); + engine->resolver.tls_padding = padding; + } else { + lua_error_p(L, "%s", errstr); + } + lua_pushboolean(L, true); + return 1; +} + +/** Shorter salt can't contain much entropy. */ +#define net_tls_sticket_MIN_SECRET_LEN 32 + +static int net_tls_sticket_secret_string(lua_State *L) +{ + struct network *net = &engine_luaget(L)->net; + + size_t secret_len; + const char *secret; + + if (lua_gettop(L) == 0) { + /* Zero-length secret, implying random key. */ + secret_len = 0; + secret = NULL; + } else { + if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) { + lua_error_p(L, + "net.tls_sticket_secret takes one parameter: (\"secret string\")"); + } + secret = lua_tolstring(L, 1, &secret_len); + if (secret_len < net_tls_sticket_MIN_SECRET_LEN || !secret) { + lua_error_p(L, "net.tls_sticket_secret - the secret is shorter than " + STR(net_tls_sticket_MIN_SECRET_LEN) " bytes"); + } + } + + tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx); + net->tls_session_ticket_ctx = + tls_session_ticket_ctx_create(net->loop, secret, secret_len); + if (net->tls_session_ticket_ctx == NULL) { + lua_error_p(L, + "net.tls_sticket_secret_string - can't create session ticket context"); + } + + lua_pushboolean(L, true); + return 1; +} + +static int net_tls_sticket_secret_file(lua_State *L) +{ + if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) { + lua_error_p(L, + "net.tls_sticket_secret_file takes one parameter: (\"file name\")"); + } + + const char *file_name = lua_tostring(L, 1); + if (strlen(file_name) == 0) + lua_error_p(L, "net.tls_sticket_secret_file - empty file name"); + + FILE *fp = fopen(file_name, "r"); + if (fp == NULL) { + lua_error_p(L, "net.tls_sticket_secret_file - can't open file '%s': %s", + file_name, strerror(errno)); + } + + char secret_buf[TLS_SESSION_TICKET_SECRET_MAX_LEN]; + const size_t secret_len = fread(secret_buf, 1, sizeof(secret_buf), fp); + int err = ferror(fp); + if (err) { + lua_error_p(L, + "net.tls_sticket_secret_file - error reading from file '%s': %s", + file_name, strerror(err)); + } + if (secret_len < net_tls_sticket_MIN_SECRET_LEN) { + lua_error_p(L, + "net.tls_sticket_secret_file - file '%s' is shorter than " + STR(net_tls_sticket_MIN_SECRET_LEN) " bytes", + file_name); + } + fclose(fp); + + struct network *net = &engine_luaget(L)->net; + + tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx); + net->tls_session_ticket_ctx = + tls_session_ticket_ctx_create(net->loop, secret_buf, secret_len); + if (net->tls_session_ticket_ctx == NULL) { + lua_error_p(L, + "net.tls_sticket_secret_file - can't create session ticket context"); + } + lua_pushboolean(L, true); + return 1; +} + +static int net_outgoing(lua_State *L, int family) +{ + union inaddr *addr; + if (family == AF_INET) + addr = (union inaddr*)&the_worker->out_addr4; + else + addr = (union inaddr*)&the_worker->out_addr6; + + if (lua_gettop(L) == 0) { /* Return the current value. */ + if (addr->ip.sa_family == AF_UNSPEC) { + lua_pushnil(L); + return 1; + } + if (addr->ip.sa_family != family) { + assert(false); + lua_error_p(L, "bad address family"); + } + char addr_buf[INET6_ADDRSTRLEN]; + int err; + if (family == AF_INET) + err = uv_ip4_name(&addr->ip4, addr_buf, sizeof(addr_buf)); + else + err = uv_ip6_name(&addr->ip6, addr_buf, sizeof(addr_buf)); + lua_error_maybe(L, err); + lua_pushstring(L, addr_buf); + return 1; + } + + if ((lua_gettop(L) != 1) || (!lua_isstring(L, 1) && !lua_isnil(L, 1))) + lua_error_p(L, "net.outgoing_vX takes one address string parameter or nil"); + + if (lua_isnil(L, 1)) { + addr->ip.sa_family = AF_UNSPEC; + return 1; + } + + const char *addr_str = lua_tostring(L, 1); + int err; + if (family == AF_INET) + err = uv_ip4_addr(addr_str, 0, &addr->ip4); + else + err = uv_ip6_addr(addr_str, 0, &addr->ip6); + if (err) + lua_error_p(L, "net.outgoing_vX: failed to parse the address"); + lua_pushboolean(L, true); + return 1; +} + +static int net_outgoing_v4(lua_State *L) { return net_outgoing(L, AF_INET); } +static int net_outgoing_v6(lua_State *L) { return net_outgoing(L, AF_INET6); } + +static int net_update_timeout(lua_State *L, uint64_t *timeout, const char *name) +{ + /* Only return current idle timeout. */ + if (lua_gettop(L) == 0) { + lua_pushinteger(L, *timeout); + return 1; + } + + if ((lua_gettop(L) != 1)) + lua_error_p(L, "%s takes one parameter: (\"idle timeout\")", name); + + if (lua_isnumber(L, 1)) { + int idle_timeout = lua_tointeger(L, 1); + if (idle_timeout <= 0) + lua_error_p(L, "%s parameter has to be positive number", name); + *timeout = idle_timeout; + } else { + lua_error_p(L, "%s parameter has to be positive number", name); + } + lua_pushboolean(L, true); + return 1; +} + +static int net_tcp_in_idle(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + struct network *net = &engine->net; + + return net_update_timeout(L, &net->tcp.in_idle_timeout, "net.tcp_in_idle"); +} + +static int net_tls_handshake_timeout(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + struct network *net = &engine->net; + + return net_update_timeout(L, &net->tcp.tls_handshake_timeout, "net.tls_handshake_timeout"); +} + +static int net_bpf_set(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + + if (lua_gettop(L) != 1 || !lua_isnumber(L, 1)) { + lua_error_p(L, "net.bpf_set(fd) takes one parameter:" + " the open file descriptor of a loaded BPF program"); + } + +#if __linux__ + + struct network *net = &engine->net; + int progfd = lua_tointeger(L, 1); + if (progfd == 0) { + /* conversion error despite that fact + * that lua_isnumber(L, 1) has returned true. + * Real or stdin? */ + lua_error_p(L, "failed to convert parameter"); + } + lua_pop(L, 1); + + if (network_set_bpf(net, progfd) == 0) { + lua_error_p(L, "failed to attach BPF program to some networks: %s", + kr_strerror(errno)); + } + + lua_pushboolean(L, 1); + return 1; + +#endif + lua_error_p(L, "BPF is not supported on this operating system"); +} + +static int net_bpf_clear(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + + if (lua_gettop(L) != 0) + lua_error_p(L, "net.bpf_clear() does not take any parameters"); + +#if __linux__ + + struct network *net = &engine->net; + network_clear_bpf(net); + + lua_pushboolean(L, 1); + return 1; + +#endif + lua_error_p(L, "BPF is not supported on this operating system"); +} + +static int net_register_endpoint_kind(lua_State *L) +{ + const int param_count = lua_gettop(L); + if (param_count != 1 && param_count != 2) + lua_error_p(L, "expected one or two parameters"); + if (!lua_isstring(L, 1)) { + lua_error_p(L, "incorrect kind '%s'", lua_tostring(L, 1)); + } + size_t kind_len; + const char *kind = lua_tolstring(L, 1, &kind_len); + struct network *net = &engine_luaget(L)->net; + + /* Unregistering */ + if (param_count == 1) { + void *val; + if (trie_del(net->endpoint_kinds, kind, kind_len, &val) == KNOT_EOK) { + const int fun_id = (char *)val - (char *)NULL; + luaL_unref(L, LUA_REGISTRYINDEX, fun_id); + return 0; + } + lua_error_p(L, "attempt to unregister unknown kind '%s'\n", kind); + } /* else */ + + /* Registering */ + assert(param_count == 2); + if (!lua_isfunction(L, 2)) { + lua_error_p(L, "second parameter: expected function but got %s\n", + lua_typename(L, lua_type(L, 2))); + } + const int fun_id = luaL_ref(L, LUA_REGISTRYINDEX); + /* ^^ The function is on top of the stack, incidentally. */ + void **pp = trie_get_ins(net->endpoint_kinds, kind, kind_len); + if (!pp) lua_error_maybe(L, kr_error(ENOMEM)); + if (*pp != NULL || !strcasecmp(kind, "dns") || !strcasecmp(kind, "tls")) + lua_error_p(L, "attempt to register known kind '%s'\n", kind); + *pp = (char *)NULL + fun_id; + /* We don't attempt to engage correspoinding endpoints now. + * That's the job for network_engage_endpoints() later. */ + return 0; +} + +int kr_bindings_net(lua_State *L) +{ + static const luaL_Reg lib[] = { + { "list", net_list }, + { "listen", net_listen }, + { "close", net_close }, + { "interfaces", net_interfaces }, + { "bufsize", net_bufsize }, + { "tcp_pipeline", net_pipeline }, + { "tls", net_tls }, + { "tls_server", net_tls }, + { "tls_client", net_tls_client }, + { "tls_client_clear", net_tls_client_clear }, + { "tls_padding", net_tls_padding }, + { "tls_sticket_secret", net_tls_sticket_secret_string }, + { "tls_sticket_secret_file", net_tls_sticket_secret_file }, + { "outgoing_v4", net_outgoing_v4 }, + { "outgoing_v6", net_outgoing_v6 }, + { "tcp_in_idle", net_tcp_in_idle }, + { "tls_handshake_timeout", net_tls_handshake_timeout }, + { "bpf_set", net_bpf_set }, + { "bpf_clear", net_bpf_clear }, + { "register_endpoint_kind", net_register_endpoint_kind }, + { NULL, NULL } + }; + luaL_register(L, "net", lib); + return 1; +} + diff --git a/utils/watcher/bindings/watcher.c b/utils/watcher/bindings/watcher.c new file mode 100644 index 000000000..d2252f3b7 --- /dev/null +++ b/utils/watcher/bindings/watcher.c @@ -0,0 +1,74 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "bindings/impl.h" + +#include "contrib/base64.h" +#include "watcher.h" + +#include + + + +static bool table_get_flag(lua_State *L, int index, const char *key, bool def) +{ + bool result = def; + lua_getfield(L, index, key); + if (lua_isboolean(L, -1)) { + result = lua_toboolean(L, -1); + } + lua_pop(L, 1); + return result; +} + +static int watcher_server(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + if (!engine) { + return 0; + } + struct watcher_context *watcher = &engine->watcher; + if (!watcher) { + return 0; + } + + /* Only return current credentials. */ + if (lua_gettop(L) == 0) { + lua_newtable(L); + lua_pushstring(L, watcher.config.auto_start); + lua_setfield(L, -2, "auto_start"); + lua_pushstring(L, watcher.config.auto_cache_gc); + lua_setfield(L, -2, "auto_cache_gc"); + lua_pushnumber(L, watcher.config.kresd_instances); + lua_setfield(L, -2, "kresd_instances"); + return 1; + } + + lua_pushboolean(L, true); + return 1; +} + +int kr_bindings_watcher(lua_State *L) +{ + static const luaL_Reg lib[] = { + { "server", watcher_server }, + + { NULL, NULL } + }; + luaL_register(L, "watcher", lib); + return 1; +} + diff --git a/utils/watcher/bindings/worker.c b/utils/watcher/bindings/worker.c new file mode 100644 index 000000000..825ad96ee --- /dev/null +++ b/utils/watcher/bindings/worker.c @@ -0,0 +1,86 @@ +/* Copyright (C) 2015-2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "bindings/impl.h" + +#include "worker.h" + +static inline double getseconds(uv_timeval_t *tv) +{ + return (double)tv->tv_sec + 0.000001*((double)tv->tv_usec); +} + +/** Return worker statistics. */ +static int wrk_stats(lua_State *L) +{ + struct worker_ctx *worker = the_worker; + if (!worker) { + return 0; + } + lua_newtable(L); + lua_pushnumber(L, worker->stats.queries); + lua_setfield(L, -2, "queries"); + lua_pushnumber(L, worker->stats.concurrent); + lua_setfield(L, -2, "concurrent"); + lua_pushnumber(L, worker->stats.dropped); + lua_setfield(L, -2, "dropped"); + + lua_pushnumber(L, worker->stats.timeout); + lua_setfield(L, -2, "timeout"); + lua_pushnumber(L, worker->stats.udp); + lua_setfield(L, -2, "udp"); + lua_pushnumber(L, worker->stats.tcp); + lua_setfield(L, -2, "tcp"); + lua_pushnumber(L, worker->stats.tls); + lua_setfield(L, -2, "tls"); + lua_pushnumber(L, worker->stats.ipv4); + lua_setfield(L, -2, "ipv4"); + lua_pushnumber(L, worker->stats.ipv6); + lua_setfield(L, -2, "ipv6"); + + /* Add subset of rusage that represents counters. */ + uv_rusage_t rusage; + if (uv_getrusage(&rusage) == 0) { + lua_pushnumber(L, getseconds(&rusage.ru_utime)); + lua_setfield(L, -2, "usertime"); + lua_pushnumber(L, getseconds(&rusage.ru_stime)); + lua_setfield(L, -2, "systime"); + lua_pushnumber(L, rusage.ru_majflt); + lua_setfield(L, -2, "pagefaults"); + lua_pushnumber(L, rusage.ru_nswap); + lua_setfield(L, -2, "swaps"); + lua_pushnumber(L, rusage.ru_nvcsw + rusage.ru_nivcsw); + lua_setfield(L, -2, "csw"); + } + /* Get RSS */ + size_t rss = 0; + if (uv_resident_set_memory(&rss) == 0) { + lua_pushnumber(L, rss); + lua_setfield(L, -2, "rss"); + } + return 1; +} + +int kr_bindings_worker(lua_State *L) +{ + static const luaL_Reg lib[] = { + { "stats", wrk_stats }, + { NULL, NULL } + }; + luaL_register(L, "worker", lib); + return 1; +} + diff --git a/utils/watcher/dbus_control.c b/utils/watcher/dbus_control.c new file mode 100644 index 000000000..1b01e57ce --- /dev/null +++ b/utils/watcher/dbus_control.c @@ -0,0 +1,119 @@ + +#include "lib/utils.h" +#include "systemd/sd-bus.h" + +#include "dbus_control.h" + +#define DBUS_SD_NAME "org.freedesktop.systemd1" +#define DBUS_SD_PATH "/org/freedesktop/systemd1" + +#define DBUS_INT_MAN "org.freedesktop.systemd1.Manager" +#define DBUS_INT_PROP "org.freedesktop.DBus.Properties" +#define DBUS_INT_UNIT "org.freedesktop.systemd1.Unit" + +#define SERVICE_KRESD "kresd@%s.service" +#define SERVICE_CACHE_GC "kres-cache-gc.service" + +#define DBUS_PATH_KRESD "/org/freedesktop/systemd1/unit/kresd_40%s_2eservice" +#define DBUS_PATH_GC "/org/freedesktop/systemd1/unit/kres_2dcache_2dgc_2eservice" + + +int kresd_get_status(const char *instance, char **status) +{ + sd_bus* bus = NULL; + void* userdata = NULL; + char *instance_service; + sd_bus_error err = SD_BUS_ERROR_NULL; + + int ret = sd_bus_default_system(&bus); + + asprintf(&instance_service, SERVICE_KRESD, instance); + + ret = sd_bus_get_property_string( + bus, + DBUS_SD_NAME, + DBUS_PATH_GC, + DBUS_INT_UNIT, + "ActiveState", + &err, + status + ); + free(instance_service); + return ret; +} + +int cache_gc_get_status(char **status) +{ + sd_bus* bus = NULL; + void* userdata = NULL; + sd_bus_error err = SD_BUS_ERROR_NULL; + + int ret = sd_bus_default_system(&bus); + + ret = sd_bus_get_property_string( + bus, + DBUS_SD_NAME, + DBUS_PATH_GC, + DBUS_INT_UNIT, + "ActiveState", + &err, + status + ); + return ret; +} + +int kresd_ctl(const char *method, const char *instance) +{ + sd_bus* bus = NULL; + char *instance_service; + sd_bus_message *reply = NULL; + sd_bus_error error = SD_BUS_ERROR_NULL; + + int ret = sd_bus_default_system(&bus); + + asprintf(&instance_service, SERVICE_KRESD, instance); + + ret = sd_bus_call_method(bus, DBUS_SD_NAME, DBUS_SD_PATH, DBUS_INT_MAN, + method, &error, &reply, "ss", instance_service, "replace"); + if (ret < 0) { + kr_log_error( + "[sdbus] failed to issue method call '%s' %s: %s\n", + method, instance_service, error.message); + goto cleanup; + } + + kr_log_info("[sdbus] %s %s\n", method, instance_service); + + cleanup: + free(instance_service); + sd_bus_error_free(&error); + sd_bus_message_unref(reply); + + return ret; +} + +int cache_gc_ctl(const char *method) +{ + sd_bus* bus = NULL; + sd_bus_message *reply = NULL; + sd_bus_error error = SD_BUS_ERROR_NULL; + + int ret = sd_bus_default_system(&bus); + + ret = sd_bus_call_method(bus, DBUS_SD_NAME, DBUS_SD_PATH, DBUS_INT_MAN, + method, &error, &reply, "ss", SERVICE_CACHE_GC, "replace"); + if (ret < 0) { + kr_log_error( + "[sdbus] failed to issue method call '%s' %s: %s\n", + method, SERVICE_CACHE_GC, error.message); + goto cleanup; + } + + kr_log_info("[sdbus] %s %s\n", method, SERVICE_CACHE_GC); + + cleanup: + sd_bus_error_free(&error); + sd_bus_message_unref(reply); + + return ret; +} \ No newline at end of file diff --git a/utils/watcher/dbus_control.h b/utils/watcher/dbus_control.h new file mode 100644 index 000000000..0ca6a78c4 --- /dev/null +++ b/utils/watcher/dbus_control.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#define UNIT_START "StartUnit" +#define UNIT_STOP "StopUnit" +#define UNIT_RESTART "RestartUnit" + + +typedef struct sdbus_ctx sdbus_ctx_t; + +typedef void (*sdbus_cb)(sdbus_ctx_t *sdbus_ctx, int status); + +/** Context for sdbus. +* might add some other fields in future */ +struct sdbus_uv_ctx { + sd_bus *bus; + sdbus_cb callback; + uv_poll_t uv_handle; +}; + +int kresd_get_status(const char *instance, char **status); + +int cache_gc_get_status(char **status); + +int kresd_ctl(const char *method, const char *instance); + +int cache_gc_ctl(const char *method); diff --git a/utils/watcher/engine.c b/utils/watcher/engine.c new file mode 100644 index 000000000..8d656bcac --- /dev/null +++ b/utils/watcher/engine.c @@ -0,0 +1,876 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include "bindings/impl.h" + +#include "kresconfig.h" +#include "engine.h" +#include "ffimodule.h" +#include "lib/nsrep.h" +#include "lib/cache/api.h" +#include "lib/defines.h" +#include "lib/cache/cdb_lmdb.h" +#include "lib/dnssec/ta.h" + +#include "watcher.h" + +/* Magic defaults for the engine. */ +#ifndef LRU_RTT_SIZE +#define LRU_RTT_SIZE 65536 /**< NS RTT cache size */ +#endif +#ifndef LRU_REP_SIZE +#define LRU_REP_SIZE (LRU_RTT_SIZE / 4) /**< NS reputation cache size */ +#endif +#ifndef LRU_COOKIES_SIZE + #ifdef ENABLE_COOKIES + #define LRU_COOKIES_SIZE LRU_RTT_SIZE /**< DNS cookies cache size. */ + #else + #define LRU_COOKIES_SIZE LRU_ASSOC /* simpler than guards everywhere */ + #endif +#endif + +/**@internal Maximum number of incomplete TCP connections in queue. +* Default is from Redis and Apache. */ +#ifndef TCP_BACKLOG_DEFAULT +#define TCP_BACKLOG_DEFAULT 511 +#endif + +/* Cleanup engine state every 5 minutes */ +const size_t CLEANUP_TIMER = 5*60*1000; + +/* Execute byte code */ +#define l_dobytecode(L, arr, len, name) \ + (luaL_loadbuffer((L), (arr), (len), (name)) || lua_pcall((L), 0, LUA_MULTRET, 0)) + +/* + * Global bindings. + */ + + +/** Print help and available commands. */ +static int l_help(lua_State *L) +{ + static const char *help_str = + "help()\n show this help\n" + "quit()\n quit\n" + "hostname()\n hostname\n" + "package_version()\n return package version\n" + "user(name[, group])\n change process user (and group)\n" + "verbose(true|false)\n toggle verbose mode\n" + "option(opt[, new_val])\n get/set server option\n" + "mode(strict|normal|permissive)\n set resolver strictness level\n" + "reorder_RR([true|false])\n set/get reordering of RRs within RRsets\n" + "resolve(name, type[, class, flags, callback])\n resolve query, callback when it's finished\n" + "todname(name)\n convert name to wire format\n" + "tojson(val)\n convert value to JSON\n" + "map(expr)\n run expression on all workers\n" + "net\n network configuration\n" + "cache\n network configuration\n" + "modules\n modules configuration\n" + "kres\n resolver services\n" + "trust_anchors\n configure trust anchors\n" + ; + lua_pushstring(L, help_str); + return 1; +} + +static bool update_privileges(int uid, int gid) +{ + if ((gid_t)gid != getgid()) { + if (setregid(gid, gid) < 0) { + return false; + } + } + if ((uid_t)uid != getuid()) { + if (setreuid(uid, uid) < 0) { + return false; + } + } + return true; +} + +/** Set process user/group. */ +static int l_setuser(lua_State *L) +{ + int n = lua_gettop(L); + if (n < 1 || !lua_isstring(L, 1)) + lua_error_p(L, "user(user[, group])"); + + /* Fetch UID/GID based on string identifiers. */ + struct passwd *user_pw = getpwnam(lua_tostring(L, 1)); + if (!user_pw) + lua_error_p(L, "invalid user name"); + int uid = user_pw->pw_uid; + int gid = getgid(); + if (n > 1 && lua_isstring(L, 2)) { + struct group *group_pw = getgrnam(lua_tostring(L, 2)); + if (!group_pw) + lua_error_p(L, "invalid group name"); + gid = group_pw->gr_gid; + } + /* Drop privileges */ + bool ret = update_privileges(uid, gid); + if (!ret) { + lua_error_maybe(L, errno); + } + lua_pushboolean(L, ret); + return 1; +} + +/** Quit current executable. */ +static int l_quit(lua_State *L) +{ + engine_stop(engine_luaget(L)); + return 0; +} + +/** Toggle verbose mode. */ +static int l_verbose(lua_State *L) +{ + if (lua_isboolean(L, 1) || lua_isnumber(L, 1)) { + kr_verbose_set(lua_toboolean(L, 1)); + } + lua_pushboolean(L, kr_verbose_status); + return 1; +} + +char *engine_get_hostname(struct engine *engine) { + static char hostname_str[KNOT_DNAME_MAXLEN]; + if (!engine) { + return NULL; + } + + if (!engine->hostname) { + if (gethostname(hostname_str, sizeof(hostname_str)) != 0) + return NULL; + return hostname_str; + } + return engine->hostname; +} + +int engine_set_hostname(struct engine *engine, const char *hostname) { + if (!engine || !hostname) { + return kr_error(EINVAL); + } + + char *new_hostname = strdup(hostname); + if (!new_hostname) { + return kr_error(ENOMEM); + } + if (engine->hostname) { + free(engine->hostname); + } + engine->hostname = new_hostname; + //network_new_hostname(&engine->net, engine); + + return 0; +} + +/** Return hostname. */ +static int l_hostname(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + if (lua_gettop(L) == 0) { + lua_pushstring(L, engine_get_hostname(engine)); + return 1; + } + if ((lua_gettop(L) != 1) || !lua_isstring(L, 1)) + lua_error_p(L, "hostname takes at most one parameter: (\"fqdn\")"); + + if (engine_set_hostname(engine, lua_tostring(L, 1)) != 0) + lua_error_p(L, "setting hostname failed"); + + lua_pushstring(L, engine_get_hostname(engine)); + return 1; +} + +/** Return server package version. */ +static int l_package_version(lua_State *L) +{ + lua_pushliteral(L, PACKAGE_VERSION); + return 1; +} + +/** Load root hints from zonefile. */ +// static int l_hint_root_file(lua_State *L) +// { +// struct engine *engine = engine_luaget(L); +// struct kr_context *ctx = &engine->resolver; +// const char *file = lua_tostring(L, 1); + +// const char *err = engine_hint_root_file(ctx, file); +// if (err) { +// if (!file) { +// file = ROOTHINTS; +// } +// lua_error_p(L, "error when opening '%s': %s", file, err); +// } else { +// lua_pushboolean(L, true); +// return 1; +// } +// } + +/** @internal for engine_hint_root_file */ +static void roothints_add(zs_scanner_t *zs) +{ + struct kr_zonecut *hints = zs->process.data; + if (!hints) { + return; + } + if (zs->r_type == KNOT_RRTYPE_A || zs->r_type == KNOT_RRTYPE_AAAA) { + kr_zonecut_add(hints, zs->r_owner, zs->r_data, zs->r_data_length); + } +} +const char* engine_hint_root_file(struct kr_context *ctx, const char *file) +{ + if (!file) { + file = ROOTHINTS; + } + if (strlen(file) == 0 || !ctx) { + return "invalid parameters"; + } + struct kr_zonecut *root_hints = &ctx->root_hints; + + zs_scanner_t zs; + if (zs_init(&zs, ".", 1, 0) != 0) { + return "not enough memory"; + } + if (zs_set_input_file(&zs, file) != 0) { + zs_deinit(&zs); + return "failed to open root hints file"; + } + + kr_zonecut_set(root_hints, (const uint8_t *)""); + zs_set_processing(&zs, roothints_add, NULL, root_hints); + zs_parse_all(&zs); + zs_deinit(&zs); + return NULL; +} + +/** Unpack JSON object to table */ +static void l_unpack_json(lua_State *L, JsonNode *table) +{ + /* Unpack POD */ + switch(table->tag) { + case JSON_STRING: lua_pushstring(L, table->string_); return; + case JSON_NUMBER: lua_pushnumber(L, table->number_); return; + case JSON_BOOL: lua_pushboolean(L, table->bool_); return; + default: break; + } + /* Unpack object or array into table */ + lua_newtable(L); + JsonNode *node = NULL; + json_foreach(node, table) { + /* Push node value */ + switch(node->tag) { + case JSON_OBJECT: /* as array */ + case JSON_ARRAY: l_unpack_json(L, node); break; + case JSON_STRING: lua_pushstring(L, node->string_); break; + case JSON_NUMBER: lua_pushnumber(L, node->number_); break; + case JSON_BOOL: lua_pushboolean(L, node->bool_); break; + default: continue; + } + /* Set table key */ + if (node->key) { + lua_setfield(L, -2, node->key); + } else { + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + } + } +} + +/** @internal Recursive Lua/JSON serialization. */ +static JsonNode *l_pack_elem(lua_State *L, int top) +{ + switch(lua_type(L, top)) { + case LUA_TSTRING: return json_mkstring(lua_tostring(L, top)); + case LUA_TNUMBER: return json_mknumber(lua_tonumber(L, top)); + case LUA_TBOOLEAN: return json_mkbool(lua_toboolean(L, top)); + case LUA_TTABLE: break; /* Table, iterate it. */ + default: return json_mknull(); + } + /* Use absolute indexes here, as the table may be nested. */ + JsonNode *node = NULL; + lua_pushnil(L); + while(lua_next(L, top) != 0) { + bool is_array = false; + if (!node) { + is_array = (lua_type(L, top + 1) == LUA_TNUMBER); + node = is_array ? json_mkarray() : json_mkobject(); + if (!node) { + return NULL; + } + } else { + is_array = node->tag == JSON_ARRAY; + } + + /* Insert to array/table. */ + JsonNode *val = l_pack_elem(L, top + 2); + if (is_array) { + json_append_element(node, val); + } else { + const char *key = lua_tostring(L, top + 1); + json_append_member(node, key, val); + } + lua_pop(L, 1); + } + /* Return empty object for empty tables. */ + return node ? node : json_mkobject(); +} + +/** @internal Serialize to string */ +static char *l_pack_json(lua_State *L, int top) +{ + JsonNode *root = l_pack_elem(L, top); + if (!root) { + return NULL; + } + char *result = json_encode(root); + json_delete(root); + return result; +} + +static int l_tojson(lua_State *L) +{ + auto_free char *json_str = l_pack_json(L, lua_gettop(L)); + if (!json_str) { + return 0; + } + lua_pushstring(L, json_str); + return 1; +} + +static int l_fromjson(lua_State *L) +{ + if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) + lua_error_p(L, "a JSON string is required"); + + const char *json_str = lua_tostring(L, 1); + JsonNode *root_node = json_decode(json_str); + + if (!root_node) + lua_error_p(L, "invalid JSON string"); + l_unpack_json(L, root_node); + json_delete(root_node); + + return 1; +} + +/** @internal Throw Lua error if expr is false */ +#define expr_checked(expr) \ + if (!(expr)) { lua_pushboolean(L, false); lua_rawseti(L, -2, lua_objlen(L, -2) + 1); continue; } + +static int l_map(lua_State *L) +{ + if (lua_gettop(L) != 1 || !lua_isstring(L, 1)) + lua_error_p(L, "map('string with a lua expression')"); + + struct engine *engine = engine_luaget(L); + const char *cmd = lua_tostring(L, 1); + uint32_t len = strlen(cmd); + lua_newtable(L); + + /* Execute on leader instance */ + int ntop = lua_gettop(L); + engine_cmd(L, cmd, true); + lua_settop(L, ntop + 1); /* Push only one return value to table */ + lua_rawseti(L, -2, 1); + + for (size_t i = 0; i < engine->ipc_set.len; ++i) { + int fd = engine->ipc_set.at[i]; + /* Send command */ + expr_checked(write(fd, &len, sizeof(len)) == sizeof(len)); + expr_checked(write(fd, cmd, len) == len); + /* Read response */ + uint32_t rlen = 0; + if (read(fd, &rlen, sizeof(rlen)) == sizeof(rlen)) { + expr_checked(rlen < UINT32_MAX); + auto_free char *rbuf = malloc(rlen + 1); + expr_checked(rbuf != NULL); + expr_checked(read(fd, rbuf, rlen) == rlen); + rbuf[rlen] = '\0'; + /* Unpack from JSON */ + JsonNode *root_node = json_decode(rbuf); + if (root_node) { + l_unpack_json(L, root_node); + } else { + lua_pushlstring(L, rbuf, rlen); + } + json_delete(root_node); + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + continue; + } + /* Didn't respond */ + lua_pushboolean(L, false); + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + } + return 1; +} + +#undef expr_checked + + +/* + * Engine API. + */ + +static int init_worker(struct engine *engine) +{ + /* Note: it had been zored by engine_init(). */ + /* Open resolution context */ + //engine->resolver.trust_anchors = map_make(NULL); + //engine->resolver.negative_anchors = map_make(NULL); + //engine->resolver.pool = engine->pool; + //engine->resolver.modules = &engine->modules; + //engine->resolver.cache_rtt_tout_retry_interval = KR_NS_TIMEOUT_RETRY_INTERVAL; + /* Create OPT RR */ + // engine->resolver.opt_rr = mm_alloc(engine->pool, sizeof(knot_rrset_t)); + // if (!engine->resolver.opt_rr) { + // return kr_error(ENOMEM); + // } + //knot_edns_init(engine->resolver.opt_rr, KR_EDNS_PAYLOAD, 0, KR_EDNS_VERSION, engine->pool); + + /* Use default TLS padding */ + //engine->resolver.tls_padding = -1; + + /* Empty init; filled via ./lua/config.lua */ + //kr_zonecut_init(&engine->resolver.root_hints, (const uint8_t *)"", engine->pool); + + /* Open NS rtt + reputation cache */ + // lru_create(&engine->resolver.cache_rtt, LRU_RTT_SIZE, NULL, NULL); + // lru_create(&engine->resolver.cache_rep, LRU_REP_SIZE, NULL, NULL); + // lru_create(&engine->resolver.cache_cookie, LRU_COOKIES_SIZE, NULL, NULL); + + /* Load basic modules */ + // engine_register(engine, "iterate", NULL, NULL); + // engine_register(engine, "validate", NULL, NULL); + // engine_register(engine, "cache", NULL, NULL); + + return array_push(engine->backends, kr_cdb_lmdb()); +} + +static int init_state(struct engine *engine) +{ + /* Initialize Lua state */ + engine->L = luaL_newstate(); + if (engine->L == NULL) { + return kr_error(ENOMEM); + } + /* Initialize used libraries. */ + luaL_openlibs(engine->L); + /* Global functions */ + lua_pushcfunction(engine->L, l_help); + lua_setglobal(engine->L, "help"); + lua_pushcfunction(engine->L, l_quit); + lua_setglobal(engine->L, "quit"); + lua_pushcfunction(engine->L, l_hostname); + lua_setglobal(engine->L, "hostname"); + lua_pushcfunction(engine->L, l_package_version); + lua_setglobal(engine->L, "package_version"); + lua_pushcfunction(engine->L, l_verbose); + lua_setglobal(engine->L, "verbose"); + lua_pushcfunction(engine->L, l_setuser); + lua_setglobal(engine->L, "user"); + // lua_pushcfunction(engine->L, l_hint_root_file); + // lua_setglobal(engine->L, "_hint_root_file"); + lua_pushliteral(engine->L, libknot_SONAME); + lua_setglobal(engine->L, "libknot_SONAME"); + lua_pushliteral(engine->L, libzscanner_SONAME); + lua_setglobal(engine->L, "libzscanner_SONAME"); + lua_pushcfunction(engine->L, l_tojson); + lua_setglobal(engine->L, "tojson"); + lua_pushcfunction(engine->L, l_fromjson); + lua_setglobal(engine->L, "fromjson"); + lua_pushcfunction(engine->L, l_map); + lua_setglobal(engine->L, "map"); + lua_pushlightuserdata(engine->L, engine); + lua_setglobal(engine->L, "__engine"); + return kr_ok(); +} + +/** + * Start luacov measurement and store results to file specified by + * KRESD_COVERAGE_STATS environment variable. + * Do nothing if the variable is not set. + */ +// static void init_measurement(struct engine *engine) +// { +// const char * const statspath = getenv("KRESD_COVERAGE_STATS"); +// if (!statspath) +// return; + +// char * snippet = NULL; +// int ret = asprintf(&snippet, +// "_luacov_runner = require('luacov.runner')\n" +// "_luacov_runner.init({\n" +// " statsfile = '%s',\n" +// " exclude = {'test', 'tapered', 'lua/5.1'},\n" +// "})\n" +// "jit.off()\n", statspath +// ); +// assert(ret > 0); (void)ret; + +// ret = luaL_loadstring(engine->L, snippet); +// assert(ret == 0); +// lua_call(engine->L, 0, 0); +// free(snippet); +// } + +int init_lua(struct engine *engine) { + if (!engine) { + return kr_error(EINVAL); + } + + /* Use libdir path for including Lua scripts */ + char l_paths[MAXPATHLEN] = { 0 }; + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wformat" /* %1$ is not in C standard */ + /* Save original package.path to package._path */ + snprintf(l_paths, MAXPATHLEN - 1, + "if package._path == nil then package._path = package.path end\n" + "package.path = '%1$s/?.lua;%1$s/?/init.lua;'..package._path\n" + "if package._cpath == nil then package._cpath = package.cpath end\n" + "package.cpath = '%1$s/?%2$s;'..package._cpath\n", + LIBDIR, LIBEXT); + #pragma GCC diagnostic pop + + int ret = l_dobytecode(engine->L, l_paths, strlen(l_paths), ""); + if (ret != 0) { + lua_pop(engine->L, 1); + return ret; + } + return 0; +} + + +int engine_init(struct engine *engine, knot_mm_t *pool) +{ + if (engine == NULL) { + return kr_error(EINVAL); + } + + memset(engine, 0, sizeof(*engine)); + engine->pool = pool; + + /* Initialize state */ + int ret = init_state(engine); + if (ret != 0) { + engine_deinit(engine); + return ret; + } + + // init_measurement(engine); + + /* Initialize worker */ + ret = init_worker(engine); + if (ret != 0) { + engine_deinit(engine); + return ret; + } + + /* Initialize watcher */ + watcher_init(&engine->watcher, uv_default_loop()); + + /* Initialize lua */ + ret = init_lua(engine); + if (ret != 0) { + engine_deinit(engine); + return ret; + } + + return ret; +} + +/** Unregister a (found) module */ +static void engine_unload(struct engine *engine, struct kr_module *module) +{ + auto_free char *name = module->name ? strdup(module->name) : NULL; + kr_module_unload(module); /* beware: lua/C mix, could be confusing */ + /* Clear in Lua world, but not for embedded modules ('cache' in particular). */ + if (name && !kr_module_get_embedded(name)) { + lua_pushnil(engine->L); + lua_setglobal(engine->L, name); + } + free(module); +} + +void engine_deinit(struct engine *engine) +{ + if (engine == NULL) { + return; + } + if (!engine->L) { + assert(false); + return; + } + /* Only close sockets and services; no need to clean up mempool. */ + + /* Network deinit is split up. We first need to stop listening, + * then we can unload modules during which we still want + * e.g. the endpoint kind registry to work (inside ->net), + * and this registry deinitization uses the lua state. */ + //network_close_force(&engine->net); + for (size_t i = 0; i < engine->ipc_set.len; ++i) { + close(engine->ipc_set.at[i]); + } + for (size_t i = 0; i < engine->modules.len; ++i) { + engine_unload(engine, engine->modules.at[i]); + } + //kr_zonecut_deinit(&engine->resolver.root_hints); + //kr_cache_close(&engine->watcher.cache); + + /* The LRUs are currently malloc-ated and need to be freed. */ + //lru_free(engine->resolver.cache_rtt); + //lru_free(engine->resolver.cache_rep); + //lru_free(engine->resolver.cache_cookie); + + watcher_deinit(&engine->watcher); + ffimodule_deinit(engine->L); + lua_close(engine->L); + + /* Free data structures */ + array_clear(engine->modules); + array_clear(engine->backends); + array_clear(engine->ipc_set); + //kr_ta_clear(&engine->resolver.trust_anchors); + //(&engine->resolver.negative_anchors); + free(engine->hostname); +} + +int engine_pcall(lua_State *L, int argc) +{ + return lua_pcall(L, argc, LUA_MULTRET, 0); +} + +int engine_cmd(lua_State *L, const char *str, bool raw) +{ + if (L == NULL) { + return kr_error(ENOEXEC); + } + + /* Evaluate results */ + lua_getglobal(L, "eval_cmd"); + lua_pushstring(L, str); + lua_pushboolean(L, raw); + + /* Check result. */ + return engine_pcall(L, 2); +} + +int engine_ipc(struct engine *engine, const char *expr) +{ + if (engine == NULL || engine->L == NULL) { + return kr_error(ENOEXEC); + } + + /* Run expression and serialize response. */ + engine_cmd(engine->L, expr, true); + if (lua_gettop(engine->L) > 0) { + l_tojson(engine->L); + return 1; + } else { + return 0; + } +} + +int engine_load_sandbox(struct engine *engine) +{ + /* Init environment */ + int ret = luaL_dofile(engine->L, LIBDIR "/sandbox-watcher.lua"); + if (ret != 0) { + fprintf(stderr, "[system] error %s\n", lua_tostring(engine->L, -1)); + lua_pop(engine->L, 1); + return kr_error(ENOEXEC); + } + ret = ffimodule_init(engine->L); + return ret; +} + +int engine_loadconf(struct engine *engine, const char *config_path) +{ + assert(config_path != NULL); + + char cwd[PATH_MAX]; + get_workdir(cwd, sizeof(cwd)); + kr_log_verbose("[system] loading config '%s' (workdir '%s')\n", config_path, cwd); + + int ret = luaL_dofile(engine->L, config_path); + if (ret != 0) { + fprintf(stderr, "[system] error while loading config: " + "%s (workdir '%s')\n", lua_tostring(engine->L, -1), cwd); + lua_pop(engine->L, 1); + } + return ret; +} + +int engine_start(struct engine *engine) +{ + /* Clean up stack */ + lua_settop(engine->L, 0); + + return kr_ok(); +} + +void engine_stop(struct engine *engine) +{ + if (!engine) { + return; + } + uv_stop(uv_default_loop()); +} + +/** @internal Find matching module */ +static size_t module_find(module_array_t *mod_list, const char *name) +{ + size_t found = mod_list->len; + for (size_t i = 0; i < mod_list->len; ++i) { + struct kr_module *mod = mod_list->at[i]; + if (strcmp(mod->name, name) == 0) { + found = i; + break; + } + } + return found; +} + +int engine_register(struct engine *engine, const char *name, const char *precedence, const char* ref) +{ + if (engine == NULL || name == NULL) { + assert(!EINVAL); + return kr_error(EINVAL); + } + /* Make sure module is unloaded */ + (void) engine_unregister(engine, name); + /* Find the index of referenced module. */ + module_array_t *mod_list = &engine->modules; + size_t ref_pos = mod_list->len; + if (precedence && ref) { + ref_pos = module_find(mod_list, ref); + if (ref_pos >= mod_list->len) { + return kr_error(EIDRM); + } + } + /* Attempt to load binary module */ + struct kr_module *module = malloc(sizeof(*module)); + if (!module) { + return kr_error(ENOMEM); + } + module->data = engine; /*< some outside modules may still use this value */ + + int ret = kr_module_load(module, name, LIBDIR "/kres_modules"); + if (ret == 0) { + /* We have a C module, loaded and init() was called. + * Now we need to prepare the lua side. */ + lua_State *L = engine->L; + lua_getglobal(L, "modules_create_table_for_c"); + lua_pushpointer(L, module); + if (lua_isnil(L, -2)) { + /* When loading the three embedded modules, we don't + * have the "modules_*" lua function yet, but fortunately + * we don't need it there. Let's just check they're embedded. + * TODO: solve this better *without* breaking stuff. */ + lua_pop(L, 2); + if (module->lib != RTLD_DEFAULT) { + ret = kr_error(1); + lua_pushliteral(L, "missing modules_create_table_for_c()"); + } + } else { + ret = engine_pcall(L, 1); + } + if (ret) { + kr_log_error("[system] internal error when loading C module %s: %s\n", + module->name, lua_tostring(L, -1)); + lua_pop(L, 1); + assert(false); /* probably not critical, but weird */ + } + + } else if (ret == kr_error(ENOENT)) { + /* No luck with C module, so try to load and .init() lua module. */ + ret = ffimodule_register_lua(engine, module, name); + if (ret != 0) { + kr_log_error("[system] failed to load module '%s'\n", name); + } + + } else if (ret == kr_error(ENOTSUP)) { + /* Print a more helpful message when module is linked against an old resolver ABI. */ + kr_log_error("[system] module '%s' links to unsupported ABI, please rebuild it\n", name); + } + + if (ret != 0) { + engine_unload(engine, module); + return ret; + } + + /* Push to the right place in engine->modules */ + if (array_push(engine->modules, module) < 0) { + engine_unload(engine, module); + return kr_error(ENOMEM); + } + if (precedence) { + struct kr_module **arr = mod_list->at; + size_t emplacement = mod_list->len; + if (strcasecmp(precedence, ">") == 0) { + if (ref_pos + 1 < mod_list->len) + emplacement = ref_pos + 1; /* Insert after target */ + } + if (strcasecmp(precedence, "<") == 0) { + emplacement = ref_pos; /* Insert at target */ + } + /* Move the tail if it has some elements. */ + if (emplacement + 1 < mod_list->len) { + memmove(&arr[emplacement + 1], &arr[emplacement], sizeof(*arr) * (mod_list->len - (emplacement + 1))); + arr[emplacement] = module; + } + } + + return kr_ok(); +} + +int engine_unregister(struct engine *engine, const char *name) +{ + module_array_t *mod_list = &engine->modules; + size_t found = module_find(mod_list, name); + if (found < mod_list->len) { + engine_unload(engine, mod_list->at[found]); + array_del(*mod_list, found); + return kr_ok(); + } + + return kr_error(ENOENT); +} + +struct engine *engine_luaget(lua_State *L) +{ + lua_getglobal(L, "__engine"); + struct engine *engine = lua_touserdata(L, -1); + if (!engine) luaL_error(L, "internal error, empty engine pointer"); + lua_pop(L, 1); + return engine; +} diff --git a/utils/watcher/engine.h b/utils/watcher/engine.h new file mode 100644 index 000000000..0e0cf678a --- /dev/null +++ b/utils/watcher/engine.h @@ -0,0 +1,84 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +/* + * @internal These are forward decls to allow building modules with engine but without Lua. + */ +struct lua_State; + +#include "lib/utils.h" +#include "lib/resolve.h" +#include "network.h" + +#include "watcher.h" + +/* @internal Array of file descriptors shorthand. */ +typedef array_t(int) fd_array_t; + +struct engine { + struct kr_context resolver; + struct watcher_context watcher; + struct network net; + module_array_t modules; + array_t(const struct kr_cdb_api *) backends; + fd_array_t ipc_set; + knot_mm_t *pool; + char *hostname; + struct lua_State *L; +}; + +int engine_init(struct engine *engine, knot_mm_t *pool); +void engine_deinit(struct engine *engine); + +/** Perform a lua command within the sandbox. + * + * @return zero on success. + * The result will be returned on the lua stack - an error message in case of failure. + * http://www.lua.org/manual/5.1/manual.html#lua_pcall */ +int engine_cmd(struct lua_State *L, const char *str, bool raw); + +/** Execute current chunk in the sandbox */ +int engine_pcall(struct lua_State *L, int argc); + +int engine_ipc(struct engine *engine, const char *expr); + + +int engine_load_sandbox(struct engine *engine); +int engine_loadconf(struct engine *engine, const char *config_path); + +/** Start the lua engine and execute the config. */ +int engine_start(struct engine *engine); +void engine_stop(struct engine *engine); +int engine_register(struct engine *engine, const char *name, const char *precedence, const char* ref); +int engine_unregister(struct engine *engine, const char *name); + +/** Return engine light userdata. */ +struct engine *engine_luaget(struct lua_State *L); + +/** Set/get the per engine hostname */ +char *engine_get_hostname(struct engine *engine); +int engine_set_hostname(struct engine *engine, const char *hostname); + +/** Load root hints from a zonefile (or config-time default if NULL). + * + * @return error message or NULL (statically allocated) + * @note exported to be usable from the hints module. + */ +KR_EXPORT +const char* engine_hint_root_file(struct kr_context *ctx, const char *file); + diff --git a/utils/watcher/ffimodule.c b/utils/watcher/ffimodule.c new file mode 100644 index 000000000..d587ed7a7 --- /dev/null +++ b/utils/watcher/ffimodule.c @@ -0,0 +1,307 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include +#include +#include + +#include "bindings/impl.h" +#include "engine.h" +#include "ffimodule.h" +#include "worker.h" +#include "lib/module.h" +#include "lib/layer.h" + +/** @internal Slots for layer callbacks. + * Each slot ID corresponds to Lua reference in module API. */ +enum { + SLOT_begin = 0, + SLOT_reset, + SLOT_finish, + SLOT_consume, + SLOT_produce, + SLOT_checkout, + SLOT_answer_finalize, + SLOT_count /* dummy, must be the last */ +}; + +/** Lua registry indices for functions that wrap layer callbacks (shared by all lua modules). */ +static int l_ffi_wrap_slots[SLOT_count] = { 0 }; + +/** @internal Continue with coroutine. */ +static void l_ffi_resume_cb(uv_idle_t *check) +{ + lua_State *L = check->data; + int status = lua_resume(L, 0); + if (status != LUA_YIELD) { + uv_idle_stop(check); /* Stop coroutine */ + uv_close((uv_handle_t *)check, (uv_close_cb)free); + } + lua_pop(L, lua_gettop(L)); +} + +/** @internal Schedule deferred continuation. */ +static int l_ffi_defer(lua_State *L) +{ + uv_idle_t *check = malloc(sizeof(*check)); + if (!check) { + return kr_error(ENOMEM); + } + uv_idle_init(uv_default_loop(), check); + check->data = L; + return uv_idle_start(check, l_ffi_resume_cb); +} + +/** @internal Helper for calling the entrypoint, for kr_module functions. */ +static int l_ffi_call_mod(lua_State *L, int argc) +{ + int status = lua_pcall(L, argc, 1, 0); + if (status != 0) { + kr_log_error("error: %s\n", lua_tostring(L, -1)); + lua_pop(L, 1); + return kr_error(EIO); + } + if (lua_isnumber(L, -1)) { /* Return code */ + status = lua_tointeger(L, -1); + } else if (lua_isthread(L, -1)) { /* Continuations */ + /* TODO: unused, possibly in a bad shape. Meant KR_STATE_YIELD? */ + assert(!ENOTSUP); + status = l_ffi_defer(lua_tothread(L, -1)); + } + lua_pop(L, 1); + return status; +} + +/** Common part of calling modname.(de)init in lua. + * The function to call should be on top of the stack and it gets popped. */ +static int l_ffi_modcb(lua_State *L, struct kr_module *module) +{ + if (lua_isnil(L, -1)) { + lua_pop(L, 1); /* .(de)init == nil, maybe even the module table doesn't exist */ + return kr_ok(); + } + lua_getglobal(L, "modules_ffi_wrap_modcb"); + lua_insert(L, -2); /* swap with .(de)init */ + lua_pushpointer(L, module); + if (lua_pcall(L, 2, 0, 0) == 0) + return kr_ok(); + kr_log_error("error: %s\n", lua_tostring(L, -1)); + lua_pop(L, 1); + return kr_error(1); +} + +static int l_ffi_deinit(struct kr_module *module) +{ + /* Call .deinit(), if it exists. */ + lua_State *L = the_worker->engine->L; + lua_getglobal(L, module->name); + lua_getfield(L, -1, "deinit"); + const int ret = l_ffi_modcb(L, module); + lua_pop(L, 1); /* the module's table */ + + const kr_layer_api_t *api = module->layer; + if (!api) { + return ret; + } + /* Unregister layer callback references from registry. */ + for (int si = 0; si < SLOT_count; ++si) { + if (api->cb_slots[si] > 0) { + luaL_unref(L, LUA_REGISTRYINDEX, api->cb_slots[si]); + } + } + free_const(api); + return ret; +} + +kr_layer_t kr_layer_t_static; + +/** @internal Helper for calling a layer Lua function by e.g. SLOT_begin. */ +static int l_ffi_call_layer(kr_layer_t *ctx, int slot_ix) +{ + const int wrap_slot = l_ffi_wrap_slots[slot_ix]; + const int cb_slot = ctx->api->cb_slots[slot_ix]; + assert(wrap_slot > 0 && cb_slot > 0); + lua_State *L = the_worker->engine->L; + lua_rawgeti(L, LUA_REGISTRYINDEX, wrap_slot); + lua_rawgeti(L, LUA_REGISTRYINDEX, cb_slot); + /* We pass the content of *ctx via a global structure to avoid + * lua (full) userdata, as that's relatively expensive (GC-allocated). + * Performance: copying isn't ideal, but it's not visible in profiles. */ + memcpy(&kr_layer_t_static, ctx, sizeof(*ctx)); + const int ret = l_ffi_call_mod(L, 1); + /* The return codes are mixed at this point. We need to return KR_STATE_* */ + return ret < 0 ? KR_STATE_FAIL : ret; +} + +static int l_ffi_layer_begin(kr_layer_t *ctx) +{ + return l_ffi_call_layer(ctx, SLOT_begin); +} + +static int l_ffi_layer_reset(kr_layer_t *ctx) +{ + return l_ffi_call_layer(ctx, SLOT_reset); +} + +static int l_ffi_layer_finish(kr_layer_t *ctx) +{ + ctx->pkt = ctx->req->answer; + return l_ffi_call_layer(ctx, SLOT_finish); +} + +static int l_ffi_layer_consume(kr_layer_t *ctx, knot_pkt_t *pkt) +{ + if (ctx->state & KR_STATE_FAIL) { + return ctx->state; /* Already failed, skip */ + } + ctx->pkt = pkt; + return l_ffi_call_layer(ctx, SLOT_consume); +} + +static int l_ffi_layer_produce(kr_layer_t *ctx, knot_pkt_t *pkt) +{ + if (ctx->state & KR_STATE_FAIL) { + return ctx->state; /* Already failed, skip */ + } + ctx->pkt = pkt; + return l_ffi_call_layer(ctx, SLOT_produce); +} + +static int l_ffi_layer_checkout(kr_layer_t *ctx, knot_pkt_t *pkt, + struct sockaddr *dst, int type) +{ + if (ctx->state & KR_STATE_FAIL) { + return ctx->state; /* Already failed, skip */ + } + ctx->pkt = pkt; + ctx->dst = dst; + ctx->is_stream = (type == SOCK_STREAM); + return l_ffi_call_layer(ctx, SLOT_checkout); +} + +static int l_ffi_layer_answer_finalize(kr_layer_t *ctx) +{ + return l_ffi_call_layer(ctx, SLOT_answer_finalize); +} + +int ffimodule_init(lua_State *L) +{ + /* Wrappers defined in ./lua/sandbox.lua */ + /* for API: (int state, kr_request_t *req) */ + lua_getglobal(L, "modules_ffi_layer_wrap1"); + const int wrap1 = luaL_ref(L, LUA_REGISTRYINDEX); + /* for API: (int state, kr_request_t *req, knot_pkt_t *) */ + lua_getglobal(L, "modules_ffi_layer_wrap2"); + const int wrap2 = luaL_ref(L, LUA_REGISTRYINDEX); + lua_getglobal(L, "modules_ffi_layer_wrap_checkout"); + const int wrap_checkout = luaL_ref(L, LUA_REGISTRYINDEX); + if (wrap1 == LUA_REFNIL || wrap2 == LUA_REFNIL || wrap_checkout == LUA_REFNIL) { + return kr_error(ENOENT); + } + + const int slots[SLOT_count] = { + [SLOT_begin] = wrap1, + [SLOT_reset] = wrap1, + [SLOT_finish] = wrap2, + [SLOT_consume] = wrap2, + [SLOT_produce] = wrap2, + [SLOT_checkout] = wrap_checkout, + [SLOT_answer_finalize] = wrap1, + }; + memcpy(l_ffi_wrap_slots, slots, sizeof(l_ffi_wrap_slots)); + return kr_ok(); +} +void ffimodule_deinit(lua_State *L) +{ + /* Unref each wrapper function from lua. + * It's probably useless, as we're about to destroy lua_State, but... */ + const int wrapsIndices[] = { + SLOT_begin, + SLOT_consume, + SLOT_checkout, + }; + for (int i = 0; i < sizeof(wrapsIndices) / sizeof(wrapsIndices[0]); ++i) { + luaL_unref(L, LUA_REGISTRYINDEX, l_ffi_wrap_slots[wrapsIndices[i]]); + } +} + +/** @internal Conditionally register layer trampoline + * @warning Expects 'module.layer' to be on top of Lua stack. */ +#define LAYER_REGISTER(L, api, name) do { \ + int *cb_slot = (api)->cb_slots + SLOT_ ## name; \ + lua_getfield((L), -1, #name); \ + if (!lua_isnil((L), -1)) { \ + (api)->name = l_ffi_layer_ ## name; \ + *cb_slot = luaL_ref((L), LUA_REGISTRYINDEX); \ + } else { \ + lua_pop((L), 1); \ + } \ +} while(0) + +/** @internal Create C layer api wrapper. */ +static kr_layer_api_t *l_ffi_layer_create(lua_State *L, struct kr_module *module) +{ + /* Fabricate layer API wrapping the Lua functions + * reserve slots after it for references to Lua callbacks. */ + const size_t api_length = offsetof(kr_layer_api_t, cb_slots) + + (SLOT_count * sizeof(module->layer->cb_slots[0])); + kr_layer_api_t *api = malloc(api_length); + if (api) { + memset(api, 0, api_length); + LAYER_REGISTER(L, api, begin); + LAYER_REGISTER(L, api, finish); + LAYER_REGISTER(L, api, consume); + LAYER_REGISTER(L, api, produce); + LAYER_REGISTER(L, api, checkout); + LAYER_REGISTER(L, api, answer_finalize); + LAYER_REGISTER(L, api, reset); + } + return api; +} + +#undef LAYER_REGISTER + +int ffimodule_register_lua(struct engine *engine, struct kr_module *module, const char *name) +{ + /* Register module in Lua */ + lua_State *L = engine->L; + lua_getglobal(L, "require"); + lua_pushfstring(L, "kres_modules.%s", name); + if (lua_pcall(L, 1, LUA_MULTRET, 0) != 0) { + kr_log_error("error: %s\n", lua_tostring(L, -1)); + lua_pop(L, 1); + return kr_error(ENOENT); + } + lua_setglobal(L, name); + lua_getglobal(L, name); + + /* Create FFI module with trampolined functions. */ + memset(module, 0, sizeof(*module)); + module->name = strdup(name); + module->deinit = &l_ffi_deinit; + /* Bake layer API if defined in module */ + lua_getfield(L, -1, "layer"); + if (!lua_isnil(L, -1)) { + module->layer = l_ffi_layer_create(L, module); + } + lua_pop(L, 1); /* .layer table */ + + /* Now call .init(), if it exists. */ + lua_getfield(L, -1, "init"); + const int ret = l_ffi_modcb(L, module); + lua_pop(L, 1); /* the module's table */ + return ret; +} diff --git a/utils/watcher/ffimodule.h b/utils/watcher/ffimodule.h new file mode 100644 index 000000000..f864f5f51 --- /dev/null +++ b/utils/watcher/ffimodule.h @@ -0,0 +1,48 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include "lib/defines.h" +#include "lib/layer.h" +#include +struct engine; +struct kr_module; + +/** + * Register Lua module as a FFI module. + * This fabricates a standard module interface, + * that trampolines to the Lua module methods. + * + * @note Lua module is loaded in it's own coroutine, + * so it's possible to yield and resume at arbitrary + * places except deinit() + * + * @param engine daemon engine + * @param module prepared module + * @param name module name + * @return 0 or an error + */ +int ffimodule_register_lua(struct engine *engine, struct kr_module *module, const char *name); + +int ffimodule_init(lua_State *L); +void ffimodule_deinit(lua_State *L); + +/** Static storage for faster passing of layer function parameters to lua callbacks. + * + * We don't need to declare it in a header, but let's give it visibility. */ +KR_EXPORT kr_layer_t kr_layer_t_static; + diff --git a/utils/watcher/io.c b/utils/watcher/io.c new file mode 100644 index 000000000..2b186070d --- /dev/null +++ b/utils/watcher/io.c @@ -0,0 +1,515 @@ +/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include +#include +#include +#include +#include + +#include "io.h" +#include "network.h" +#include "worker.h" +#include "tls.h" +#include "session.h" + +#define negotiate_bufsize(func, handle, bufsize_want) do { \ + int bufsize = 0; (func)((handle), &bufsize); \ + if (bufsize < (bufsize_want)) { \ + bufsize = (bufsize_want); \ + (func)((handle), &bufsize); \ + } \ +} while (0) + +static void check_bufsize(uv_handle_t* handle) +{ + return; /* TODO: resurrect after https://github.com/libuv/libuv/issues/419 */ + /* We want to buffer at least N waves in advance. + * This is magic presuming we can pull in a whole recvmmsg width in one wave. + * Linux will double this the bufsize wanted. + */ + const int bufsize_want = 2 * sizeof( ((struct worker_ctx *)NULL)->wire_buf ) ; + negotiate_bufsize(uv_recv_buffer_size, handle, bufsize_want); + negotiate_bufsize(uv_send_buffer_size, handle, bufsize_want); +} + +#undef negotiate_bufsize + +static void handle_getbuf(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) +{ + /* UDP sessions use worker buffer for wire data, + * TCP sessions use session buffer for wire data + * (see session_set_handle()). + * TLS sessions use buffer from TLS context. + * The content of the worker buffer is + * guaranteed to be unchanged only for the duration of + * udp_read() and tcp_read(). + */ + struct session *s = handle->data; + if (!session_flags(s)->has_tls) { + buf->base = (char *) session_wirebuf_get_free_start(s); + buf->len = session_wirebuf_get_free_size(s); + } else { + struct tls_common_ctx *ctx = session_tls_get_common_ctx(s); + buf->base = (char *) ctx->recv_buf; + buf->len = sizeof(ctx->recv_buf); + } +} + +void udp_recv(uv_udp_t *handle, ssize_t nread, const uv_buf_t *buf, + const struct sockaddr *addr, unsigned flags) +{ + uv_loop_t *loop = handle->loop; + struct worker_ctx *worker = loop->data; + struct session *s = handle->data; + if (session_flags(s)->closing) { + return; + } + if (nread <= 0) { + if (nread < 0) { /* Error response, notify resolver */ + worker_submit(s, NULL, NULL); + } /* nread == 0 is for freeing buffers, we don't need to do this */ + return; + } + if (addr->sa_family == AF_UNSPEC) { + return; + } + if (session_flags(s)->outgoing) { + const struct sockaddr *peer = session_get_peer(s); + assert(peer->sa_family != AF_UNSPEC); + if (kr_sockaddr_cmp(peer, addr) != 0) { + kr_log_verbose("[io] <= ignoring UDP from unexpected address '%s'\n", + kr_straddr(addr)); + return; + } + } + ssize_t consumed = session_wirebuf_consume(s, (const uint8_t *)buf->base, + nread); + assert(consumed == nread); (void)consumed; + session_wirebuf_process(s, addr); + session_wirebuf_discard(s); + mp_flush(worker->pkt_pool.ctx); +} + +static int family_to_freebind_option(sa_family_t sa_family, int *level, int *name) +{ + switch (sa_family) { + case AF_INET: + *level = IPPROTO_IP; +#if defined(IP_FREEBIND) + *name = IP_FREEBIND; +#elif defined(IP_BINDANY) + *name = IP_BINDANY; +#else + return kr_error(ENOTSUP); +#endif + break; + case AF_INET6: +#if defined(IP_FREEBIND) + *level = IPPROTO_IP; + *name = IP_FREEBIND; +#elif defined(IPV6_BINDANY) + *level = IPPROTO_IPV6; + *name = IPV6_BINDANY; +#else + return kr_error(ENOTSUP); +#endif + break; + default: + return kr_error(ENOTSUP); + } + return kr_ok(); +} + +int io_bind(const struct sockaddr *addr, int type, const endpoint_flags_t *flags) +{ + const int fd = socket(addr->sa_family, type, 0); + if (fd < 0) return kr_error(errno); + + int yes = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes))) + return kr_error(errno); + +#ifdef SO_REUSEPORT_LB + if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT_LB, &yes, sizeof(yes))) + return kr_error(errno); +#elif defined(SO_REUSEPORT) && defined(__linux__) /* different meaning on (Free)BSD */ + if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &yes, sizeof(yes))) + return kr_error(errno); +#endif + +#ifdef IPV6_V6ONLY + if (addr->sa_family == AF_INET6 + && setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &yes, sizeof(yes))) + return kr_error(errno); +#endif + if (flags != NULL && flags->freebind) { + int optlevel; + int optname; + int ret = family_to_freebind_option(addr->sa_family, &optlevel, &optname); + if (ret) return kr_error(ret); + if (setsockopt(fd, optlevel, optname, &yes, sizeof(yes))) + return kr_error(errno); + } + + if (bind(fd, addr, kr_sockaddr_len(addr))) + return kr_error(errno); + + return fd; +} + +int io_listen_udp(uv_loop_t *loop, uv_udp_t *handle, int fd) +{ + if (!handle) { + return kr_error(EINVAL); + } + int ret = uv_udp_init(loop, handle); + if (ret) return ret; + + ret = uv_udp_open(handle, fd); + if (ret) return ret; + + uv_handle_t *h = (uv_handle_t *)handle; + check_bufsize(h); + /* Handle is already created, just create context. */ + struct session *s = session_new(h, false); + assert(s); + session_flags(s)->outgoing = false; + + int socklen = sizeof(union inaddr); + ret = uv_udp_getsockname(handle, session_get_sockname(s), &socklen); + if (ret) { + kr_log_error("ERROR: getsockname failed: %s\n", uv_strerror(ret)); + abort(); /* It might be nontrivial not to leak something here. */ + } + + return io_start_read(h); +} + +void tcp_timeout_trigger(uv_timer_t *timer) +{ + struct session *s = timer->data; + + assert(!session_flags(s)->closing); + + struct worker_ctx *worker = timer->loop->data; + + if (!session_tasklist_is_empty(s)) { + int finalized = session_tasklist_finalize_expired(s); + worker->stats.timeout += finalized; + /* session_tasklist_finalize_expired() may call worker_task_finalize(). + * If session is a source session and there were IO errors, + * worker_task_finalize() can filnalize all tasks and close session. */ + if (session_flags(s)->closing) { + return; + } + + } + if (!session_tasklist_is_empty(s)) { + uv_timer_stop(timer); + session_timer_start(s, tcp_timeout_trigger, + KR_RESOLVE_TIME_LIMIT / 2, + KR_RESOLVE_TIME_LIMIT / 2); + } else { + /* Normally it should not happen, + * but better to check if there anything in this list. */ + while (!session_waitinglist_is_empty(s)) { + struct qr_task *t = session_waitinglist_pop(s, false); + worker_task_finalize(t, KR_STATE_FAIL); + worker_task_unref(t); + worker->stats.timeout += 1; + if (session_flags(s)->closing) { + return; + } + } + const struct engine *engine = worker->engine; + const struct network *net = &engine->net; + uint64_t idle_in_timeout = net->tcp.in_idle_timeout; + uint64_t last_activity = session_last_activity(s); + uint64_t idle_time = kr_now() - last_activity; + if (idle_time < idle_in_timeout) { + idle_in_timeout -= idle_time; + uv_timer_stop(timer); + session_timer_start(s, tcp_timeout_trigger, + idle_in_timeout, idle_in_timeout); + } else { + struct sockaddr *peer = session_get_peer(s); + char *peer_str = kr_straddr(peer); + kr_log_verbose("[io] => closing connection to '%s'\n", + peer_str ? peer_str : ""); + if (session_flags(s)->outgoing) { + worker_del_tcp_waiting(worker, peer); + worker_del_tcp_connected(worker, peer); + } + session_close(s); + } + } +} + +static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) +{ + struct session *s = handle->data; + assert(s && session_get_handle(s) == (uv_handle_t *)handle && + handle->type == UV_TCP); + + if (session_flags(s)->closing) { + return; + } + + /* nread might be 0, which does not indicate an error or EOF. + * This is equivalent to EAGAIN or EWOULDBLOCK under read(2). */ + if (nread == 0) { + return; + } + + if (nread < 0 || !buf->base) { + if (kr_verbose_status) { + struct sockaddr *peer = session_get_peer(s); + char *peer_str = kr_straddr(peer); + kr_log_verbose("[io] => connection to '%s' closed by peer (%s)\n", + peer_str ? peer_str : "", + uv_strerror(nread)); + } + worker_end_tcp(s); + return; + } + + ssize_t consumed = 0; + const uint8_t *data = (const uint8_t *)buf->base; + ssize_t data_len = nread; + if (session_flags(s)->has_tls) { + /* buf->base points to start of the tls receive buffer. + Decode data free space in session wire buffer. */ + consumed = tls_process_input_data(s, (const uint8_t *)buf->base, nread); + if (consumed < 0) { + if (kr_verbose_status) { + struct sockaddr *peer = session_get_peer(s); + char *peer_str = kr_straddr(peer); + kr_log_verbose("[io] => connection to '%s': " + "error processing TLS data, close\n", + peer_str ? peer_str : ""); + } + worker_end_tcp(s); + return; + } else if (consumed == 0) { + return; + } + data = session_wirebuf_get_free_start(s); + data_len = consumed; + } + + /* data points to start of the free space in session wire buffer. + Simple increase internal counter. */ + consumed = session_wirebuf_consume(s, data, data_len); + assert(consumed == data_len); + + int ret = session_wirebuf_process(s, session_get_peer(s)); + if (ret < 0) { + /* An error has occurred, close the session. */ + worker_end_tcp(s); + } + session_wirebuf_compress(s); + struct worker_ctx *worker = handle->loop->data; + mp_flush(worker->pkt_pool.ctx); +} + +static void _tcp_accept(uv_stream_t *master, int status, bool tls) +{ + if (status != 0) { + return; + } + + struct worker_ctx *worker = the_worker; + uv_tcp_t *client = malloc(sizeof(uv_tcp_t)); + if (!client) { + return; + } + int res = io_create(master->loop, (uv_handle_t *)client, + SOCK_STREAM, AF_UNSPEC, tls); + if (res) { + if (res == UV_EMFILE) { + worker->too_many_open = true; + worker->rconcurrent_highwatermark = worker->stats.rconcurrent; + } + /* Since res isn't OK struct session wasn't allocated \ borrowed. + * We must release client handle only. + */ + free(client); + return; + } + + /* struct session was allocated \ borrowed from memory pool. */ + struct session *s = client->data; + assert(session_flags(s)->outgoing == false); + assert(session_flags(s)->has_tls == tls); + + if (uv_accept(master, (uv_stream_t *)client) != 0) { + /* close session, close underlying uv handles and + * deallocate (or return to memory pool) memory. */ + session_close(s); + return; + } + + /* Get peer's and our address. We apparently get specific sockname here + * even if we listened on a wildcard address. */ + struct sockaddr *sa = session_get_peer(s); + int sa_len = sizeof(struct sockaddr_in6); + int ret = uv_tcp_getpeername(client, sa, &sa_len); + if (ret || sa->sa_family == AF_UNSPEC) { + session_close(s); + return; + } + sa = session_get_sockname(s); + sa_len = sizeof(struct sockaddr_in6); + ret = uv_tcp_getsockname(client, sa, &sa_len); + if (ret || sa->sa_family == AF_UNSPEC) { + session_close(s); + return; + } + + /* Set deadlines for TCP connection and start reading. + * It will re-check every half of a request time limit if the connection + * is idle and should be terminated, this is an educated guess. */ + + const struct network *net = &worker->engine->net; + uint64_t idle_in_timeout = net->tcp.in_idle_timeout; + + uint64_t timeout = KR_CONN_RTT_MAX / 2; + if (tls) { + timeout += TLS_MAX_HANDSHAKE_TIME; + struct tls_ctx_t *ctx = session_tls_get_server_ctx(s); + if (!ctx) { + ctx = tls_new(worker); + if (!ctx) { + session_close(s); + return; + } + ctx->c.session = s; + ctx->c.handshake_state = TLS_HS_IN_PROGRESS; + session_tls_set_server_ctx(s, ctx); + } + } + session_timer_start(s, tcp_timeout_trigger, timeout, idle_in_timeout); + io_start_read((uv_handle_t *)client); +} + +static void tcp_accept(uv_stream_t *master, int status) +{ + _tcp_accept(master, status, false); +} + +static void tls_accept(uv_stream_t *master, int status) +{ + _tcp_accept(master, status, true); +} + +int io_listen_tcp(uv_loop_t *loop, uv_tcp_t *handle, int fd, int tcp_backlog, bool has_tls) +{ + const uv_connection_cb connection = has_tls ? tls_accept : tcp_accept; + if (!handle) { + return kr_error(EINVAL); + } + int ret = uv_tcp_init(loop, handle); + if (ret) return ret; + + ret = uv_tcp_open(handle, (uv_os_sock_t) fd); + if (ret) return ret; + + int val; (void)val; + /* TCP_DEFER_ACCEPT delays accepting connections until there is readable data. */ +#ifdef TCP_DEFER_ACCEPT + val = KR_CONN_RTT_MAX/1000; + if (setsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &val, sizeof(val))) { + kr_log_error("[ io ] listen TCP (defer_accept): %s\n", strerror(errno)); + } +#endif + + ret = uv_listen((uv_stream_t *)handle, tcp_backlog, connection); + if (ret != 0) { + return ret; + } + + /* TCP_FASTOPEN enables 1 RTT connection resumptions. */ +#ifdef TCP_FASTOPEN + #ifdef __linux__ + val = 16; /* Accepts queue length hint */ + #else + val = 1; /* Accepts on/off */ + #endif + if (setsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN, &val, sizeof(val))) { + kr_log_error("[ io ] listen TCP (fastopen): %s\n", strerror(errno)); + } +#endif + + handle->data = NULL; + return 0; +} + +int io_create(uv_loop_t *loop, uv_handle_t *handle, int type, unsigned family, bool has_tls) +{ + int ret = -1; + if (type == SOCK_DGRAM) { + ret = uv_udp_init(loop, (uv_udp_t *)handle); + } else if (type == SOCK_STREAM) { + ret = uv_tcp_init_ex(loop, (uv_tcp_t *)handle, family); + uv_tcp_nodelay((uv_tcp_t *)handle, 1); + } + if (ret != 0) { + return ret; + } + struct session *s = session_new(handle, has_tls); + if (s == NULL) { + ret = -1; + } + return ret; +} + +void io_deinit(uv_handle_t *handle) +{ + if (!handle) { + return; + } + session_free(handle->data); + handle->data = NULL; +} + +void io_free(uv_handle_t *handle) +{ + io_deinit(handle); + free(handle); +} + +int io_start_read(uv_handle_t *handle) +{ + switch (handle->type) { + case UV_UDP: + return uv_udp_recv_start((uv_udp_t *)handle, &handle_getbuf, &udp_recv); + case UV_TCP: + return uv_read_start((uv_stream_t *)handle, &handle_getbuf, &tcp_recv); + default: + assert(!EINVAL); + return kr_error(EINVAL); + } +} + +int io_stop_read(uv_handle_t *handle) +{ + if (handle->type == UV_UDP) { + return uv_udp_recv_stop((uv_udp_t *)handle); + } else { + return uv_read_stop((uv_stream_t *)handle); + } +} diff --git a/utils/watcher/io.h b/utils/watcher/io.h new file mode 100644 index 000000000..f66168f41 --- /dev/null +++ b/utils/watcher/io.h @@ -0,0 +1,47 @@ +/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include +#include +#include +#include "lib/generic/array.h" +#include "worker.h" + +struct tls_ctx_t; +struct tls_client_ctx_t; + +/** Bind address into a file-descriptor (only, no libuv). type is e.g. SOCK_DGRAM */ +int io_bind(const struct sockaddr *addr, int type, const endpoint_flags_t *flags); +/** Initialize a UDP handle and start listening. */ +int io_listen_udp(uv_loop_t *loop, uv_udp_t *handle, int fd); +/** Initialize a TCP handle and start listening. */ +int io_listen_tcp(uv_loop_t *loop, uv_tcp_t *handle, int fd, int tcp_backlog, bool has_tls); + +void tcp_timeout_trigger(uv_timer_t *timer); + +/** Initialize the handle, incl. ->data = struct session * instance. + * \param type = SOCK_* + * \param family = AF_* + * \param has_tls has meanings only when type is SOCK_STREAM */ +int io_create(uv_loop_t *loop, uv_handle_t *handle, int type, + unsigned family, bool has_tls); +void io_deinit(uv_handle_t *handle); +void io_free(uv_handle_t *handle); + +int io_start_read(uv_handle_t *handle); +int io_stop_read(uv_handle_t *handle); diff --git a/utils/watcher/lua/config-watcher.lua b/utils/watcher/lua/config-watcher.lua new file mode 100644 index 000000000..8663999b2 --- /dev/null +++ b/utils/watcher/lua/config-watcher.lua @@ -0,0 +1,5 @@ +-- Open cache if not set/disabled +if not cache.current_size then + cache.size = 100 * MB +end + diff --git a/utils/watcher/lua/meson.build b/utils/watcher/lua/meson.build new file mode 100644 index 000000000..60079d25f --- /dev/null +++ b/utils/watcher/lua/meson.build @@ -0,0 +1,18 @@ +# kres-watcher: lua modules + +sandbox_watcher = configure_file( + input: 'sandbox-watcher.lua.in', + output: 'sandbox-watcher.lua', + configuration: lua_config, +) + +lua_watcher_src = [ + sandbox_watcher, + files('config-watcher.lua'), +] + +# install daemon lua sources +install_data( + lua_watcher_src, + install_dir: lib_dir, +) \ No newline at end of file diff --git a/utils/watcher/lua/sandbox-watcher.lua.in b/utils/watcher/lua/sandbox-watcher.lua.in new file mode 100644 index 000000000..1df086862 --- /dev/null +++ b/utils/watcher/lua/sandbox-watcher.lua.in @@ -0,0 +1,631 @@ +local debug = require('debug') +local ffi = require('ffi') + +-- Units +kB = 1024 +MB = 1024*kB +GB = 1024*MB +-- Time +sec = 1000 +second = sec +minute = 60 * sec +min = minute +hour = 60 * minute +day = 24 * hour + +-- Logging +function panic(fmt, ...) + print(debug.traceback('error occured here (config filename:lineno is ' + .. 'at the bottom, if config is involved):', 2)) + error(string.format('ERROR: '.. fmt, ...), 0) +end +function warn(fmt, ...) + io.stderr:write(string.format(fmt..'\n', ...)) +end +function log(fmt, ...) + print(string.format(fmt, ...)) +end + +-- Resolver bindings +kres = require('kres') +if rawget(kres, 'str2dname') ~= nil then + todname = kres.str2dname +end + +worker.resolve_pkt = function (pkt, options, finish, init) + options = kres.mk_qflags(options) + local task = ffi.C.worker_resolve_start(pkt, options) + + -- Deal with finish and init callbacks + if finish ~= nil then + local finish_cb + finish_cb = ffi.cast('trace_callback_f', + function (req) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + finish(req.answer, req) + finish_cb:free() + end) + task.ctx.req.trace_finish = finish_cb + end + if init ~= nil then + init(task.ctx.req) + end + + return ffi.C.worker_resolve_exec(task, pkt) == 0 +end + +worker.resolve = function (qname, qtype, qclass, options, finish, init) + -- Alternatively use named arguments + if type(qname) == 'table' then + local t = qname + qname = t.name + qtype = t.type + qclass = t.class + options = t.options + finish = t.finish + init = t.init + end + qtype = qtype or kres.type.A + qclass = qclass or kres.class.IN + options = kres.mk_qflags(options) + -- LATER: nicer errors for rubbish in qname, qtype, qclass? + local pkt = ffi.C.worker_resolve_mk_pkt(qname, qtype, qclass, options) + if pkt == nil then + panic('failure in worker.resolve(); probably invalid qname "%s"', qname) + end + local ret = worker.resolve_pkt(pkt, options, finish, init) + ffi.C.knot_rrset_free(pkt.opt_rr, nil); + ffi.C.knot_pkt_free(pkt); + return ret +end +resolve = worker.resolve + +-- Shorthand for aggregated per-worker information +worker.info = function () + local t = worker.stats() + t.pid = worker.pid + return t +end + +-- Resolver mode of operation +local current_mode = 'normal' +local mode_table = { normal=0, strict=1, permissive=2 } +function mode(m) + if not m then return current_mode end + if not mode_table[m] then error('unsupported mode: '..m) end + -- Update current operation mode + current_mode = m + option('STRICT', current_mode == 'strict') + option('PERMISSIVE', current_mode == 'permissive') + return true +end + +-- Trivial option alias +function reorder_RR(val) + return option('REORDER_RR', val) +end + +-- Get/set resolver options via name (string) +function option(name, val) + local flags = kres.context().options; + -- Note: no way to test existence of flags[name] but we want error anyway. + name = string.upper(name) -- convenience + if val ~= nil then + if (val ~= true) and (val ~= false) then + panic('invalid option value: ' .. tostring(val)) + end + flags[name] = val; + end + return flags[name]; +end + +-- Function aliases +-- `env.VAR returns os.getenv(VAR)` +env = {} +setmetatable(env, { + __index = function (_, k) return os.getenv(k) end +}) + +-- Quick access to interfaces +-- `net.` => `net.interfaces()[iface]` +-- `net = {addr1, ..}` => `net.listen(name, addr1)` +-- `net.ipv{4,6} = {true, false}` => enable/disable IPv{4,6} +setmetatable(net, { + __index = function (t, k) + local v = rawget(t, k) + if v then return v + elseif k == 'ipv6' then return not option('NO_IPV6') + elseif k == 'ipv4' then return not option('NO_IPV4') + else return net.interfaces()[k] + end + end, + __newindex = function (t,k,v) + if k == 'ipv6' then return option('NO_IPV6', not v) + elseif k == 'ipv4' then return option('NO_IPV4', not v) + else + local iname = rawget(net.interfaces(), v) + if iname then t.listen(iname) + else t.listen(v) + end + end + end +}) + +-- Syntactic sugar for module loading +-- `modules. = ` +setmetatable(modules, { + __newindex = function (_, k, v) + if type(k) == 'number' then + k, v = v, nil + end + if not rawget(_G, k) then + modules.load(k) + k = string.match(k, '[%w_]+') + local mod = _G[k] + local config = mod and rawget(mod, 'config') + if mod ~= nil and config ~= nil then + if k ~= v then config(v) + else config() + end + end + end + end +}) + +-- Set up lua table for a C module. (Internal function.) +function modules_create_table_for_c(kr_module_ud) + local kr_module = ffi.cast('struct kr_module **', kr_module_ud)[0] + --- Set up the global table named according to the module. + if kr_module.config == nil and kr_module.props == nil then + return + end + local module = {} + local module_name = ffi.string(kr_module.name) + _G[module_name] = module + + --- Construct lua functions for properties. + if kr_module.props ~= nil then + local i = 0 + while true do + local prop = kr_module.props[i] + local cb = prop.cb + if cb == nil then break; end + module[ffi.string(prop.name)] = + function (arg) -- lua wrapper around kr_prop_cb function typedef + local arg_conv + if type(arg) == 'table' or type(arg) == 'boolean' then + arg_conv = tojson(arg) + elseif arg ~= nil then + arg_conv = tostring(arg) + end + local ret_cstr = cb(__engine, kr_module, arg_conv) + if ret_cstr == nil then + return nil + end + -- LATER(optim.): superfluous copying + local ret_str = ffi.string(ret_cstr) + -- This is a bit ugly, but the API is that invalid JSON + -- should be just returned as string :-( + local status, ret = pcall(fromjson, ret_str) + if not status then ret = ret_str end + ffi.C.free(ret_cstr) + return ret + end + i = i + 1 + end + end + + --- Construct lua function for config(). + if kr_module.config ~= nil then + module.config = + function (arg) + local arg_conv + if type(arg) == 'table' or type(arg) == 'boolean' then + arg_conv = tojson(arg) + elseif arg ~= nil then + arg_conv = tostring(arg) + end + return kr_module.config(kr_module, arg_conv) + end + end + + --- Add syntactic sugar for get() and set() properties. + --- That also "catches" any commands like `moduleName.foo = bar`. + local m_index, m_newindex + local get_f = rawget(module, 'get') + if get_f ~= nil then + m_index = function (_, key) + return get_f(key) + end + else + m_index = function () + error('module ' .. module_name .. ' does not support indexing syntax sugar') + end + end + local set_f = rawget(module, 'set') + if set_f ~= nil then + m_newindex = function (_, key, value) + -- This will produce a nasty error on some non-string parameters. + -- Still, we already use it with integer values, e.g. in predict module :-/ + return set_f(key .. ' ' .. value) + end + else + m_newindex = function () + error('module ' .. module_name .. ' does not support assignment syntax sugar') + end + end + setmetatable(module, { + -- note: the two functions only get called for *missing* indices + __index = m_index, + __newindex = m_newindex, + }) +end + +local layer_ctx = ffi.C.kr_layer_t_static +-- Utilities internal for lua layer glue; see ../ffimodule.c +modules_ffi_layer_wrap1 = function (layer_cb) + return layer_cb(layer_ctx.state, layer_ctx.req) +end +modules_ffi_layer_wrap2 = function (layer_cb) + return layer_cb(layer_ctx.state, layer_ctx.req, layer_ctx.pkt) +end +modules_ffi_layer_wrap_checkout = function (layer_cb) + return layer_cb(layer_ctx.state, layer_ctx.req, layer_ctx.pkt, + layer_ctx.dst, layer_ctx.is_stream) +end +modules_ffi_wrap_modcb = function (cb, kr_module_ud) -- this one isn't for layer + local kr_module = ffi.cast('struct kr_module **', kr_module_ud)[0] + return cb(kr_module) +end + +cache.clear = function (name, exact_name, rr_type, chunk_size, callback, prev_state) + if name == nil or (name == '.' and not exact_name) then + -- keep same output format as for 'standard' clear + local total_count = cache.count() + if not cache.clear_everything() then + error('unable to clear everything') + end + return {count = total_count} + end + -- Check parameters, in order, and set defaults if missing. + local dname = kres.str2dname(name) + if not dname then error('cache.clear(): incorrect name passed') end + if exact_name == nil then exact_name = false end + if type(exact_name) ~= 'boolean' + then error('cache.clear(): incorrect exact_name passed') end + + local cach = kres.context().cache; + local rettable = {} + -- Apex warning. If the caller passes a custom callback, + -- we assume they are advanced enough not to need the check. + -- The point is to avoid repeating the check in each callback iteration. + if callback == nil then + local apex_array = ffi.new('knot_dname_t *[1]') -- C: dname **apex_array + local ret = ffi.C.kr_cache_closest_apex(cach, dname, false, apex_array) + if ret < 0 then + error(ffi.string(ffi.C.knot_strerror(ret))) end + if not ffi.C.knot_dname_is_equal(apex_array[0], dname) then + local apex_str = kres.dname2str(apex_array[0]) + rettable.not_apex = 'to clear proofs of non-existence call ' + .. 'cache.clear(\'' .. tostring(apex_str) ..'\')' + rettable.subtree = apex_str + end + ffi.C.free(apex_array[0]) + end + + if rr_type ~= nil then + -- Special case, without any subtree searching. + if not exact_name + then error('cache.clear(): specifying rr_type only supported with exact_name') end + if chunk_size or callback + then error('cache.clear(): chunk_size and callback parameters not supported with rr_type') end + local ret = ffi.C.kr_cache_remove(cach, dname, rr_type) + if ret < 0 then error(ffi.string(ffi.C.knot_strerror(ret))) end + return {count = 1} + end + + if chunk_size == nil then chunk_size = 100 end + if type(chunk_size) ~= 'number' or chunk_size <= 0 + then error('cache.clear(): chunk_size has to be a positive integer') end + + -- Do the C call, and add chunk_size warning. + rettable.count = ffi.C.kr_cache_remove_subtree(cach, dname, exact_name, chunk_size) + if rettable.count == chunk_size then + local msg_extra = '' + if callback == nil then + msg_extra = '; the default callback will continue asynchronously' + end + rettable.chunk_limit = 'chunk size limit reached' .. msg_extra + end + + -- Default callback function: repeat after 1ms + if callback == nil then callback = + function (cbname, cbexact_name, cbrr_type, cbchunk_size, cbself, cbprev_state, cbrettable) + if cbrettable.count < 0 then error(ffi.string(ffi.C.knot_strerror(cbrettable.count))) end + if cbprev_state == nil then cbprev_state = { round = 0 } end + if type(cbprev_state) ~= 'table' + then error('cache.clear() callback: incorrect prev_state passed') end + cbrettable.round = cbprev_state.round + 1 + if (cbrettable.count == cbchunk_size) then + event.after(1, function () + cache.clear(cbname, cbexact_name, cbrr_type, cbchunk_size, cbself, cbrettable) + end) + elseif cbrettable.round > 1 then + log('[cache] asynchonous cache.clear(\'' .. cbname .. '\', ' + .. tostring(cbexact_name) .. ') finished') + end + return cbrettable + end + end + return callback(name, exact_name, rr_type, chunk_size, callback, prev_state, rettable) +end +-- Syntactic sugar for cache +-- `cache[x] -> cache.get(x)` +-- `cache.{size|storage} = value` +setmetatable(cache, { + __index = function (t, k) + local res = rawget(t, k) + if not res and not rawget(t, 'current_size') then return res end + -- Beware: t.get returns empty table on failure to find. + -- That would be confusing here (breaking kresc), so return nil instead. + res = t.get(k) + if res and next(res) ~= nil then return res else return nil end + end, + __newindex = function (t,k,v) + -- Defaults + local storage = rawget(t, 'current_storage') + if not storage then storage = 'lmdb://' end + local size = rawget(t, 'current_size') + if not size then size = 10*MB end + -- Declarative interface for cache + if k == 'size' then t.open(v, storage) + elseif k == 'storage' then t.open(size, v) end + end +}) + +-- Make sandboxed environment +local function make_sandbox(defined) + local __protected = { worker = true, env = true, modules = true, cache = true, net = true, trust_anchors = true } + + -- Compute and export the list of top-level names (hidden otherwise) + local nl = "" + for n in pairs(defined) do + nl = nl .. n .. "\n" + end + + return setmetatable({ __orig_name_list = nl }, { + __index = defined, + __newindex = function (_, k, v) + if __protected[k] then + for k2,v2 in pairs(v) do + defined[k][k2] = v2 + end + else + defined[k] = v + end + end + }) +end + +-- Compatibility sandbox +_G = make_sandbox(getfenv(0)) +setfenv(0, _G) + +-- Load default modules +--trust_anchors = require('trust_anchors') +--modules.load('ta_update') +--modules.load('ta_signal_query') +--modules.load('policy') +--modules.load('priming') +--modules.load('detect_time_skew') +--modules.load('detect_time_jump') +--modules.load('ta_sentinel') +--modules.load('edns_keepalive') +--modules.load('refuse_nord') +--modules.load('watchdog') + +modules.load('prefill') +prefill.config({ + ['.'] = { + url = 'https://www.internic.net/domain/root.zone', + interval = 86400, + ca_file = '/etc/pki/tls/certs/ca-bundle.crt', + } +}) + +-- Load keyfile_default +-- trust_anchors.add_file('@keyfile_default@', @unmanaged@) + +-- Interactive command evaluation +function eval_cmd(line, raw) + -- Compatibility sandbox code loading + local function load_code(code) + if getfenv then -- Lua 5.1 + return loadstring(code) + else -- Lua 5.2+ + return load(code, nil, 't', _ENV) + end + end + local err, chunk + chunk, err = load_code(raw and 'return '..line or 'return table_print('..line..')') + if err then + chunk, err = load_code(line) + end + if not err then + return chunk() + else + error(err) + end +end + +-- Pretty printing + +local function funcsign(f) +-- thanks to AnandA777 from StackOverflow! Function funcsign is adapted version of +-- https://stackoverflow.com/questions/51095022/inspect-function-signature-in-lua-5-1 + assert(type(f) == 'function', "bad argument #1 to 'funcsign' (function expected)") + local debuginfo = debug.getinfo(f) + if debuginfo.what == 'C' then -- names N/A + return '(?)' + end + + local func_args = {} + pcall(function() + local oldhook + local delay = 2 + local function hook() + delay = delay - 1 + if delay == 0 then -- call this only for the introspected function + -- stack depth 2 is the introspected function + for i = 1, debuginfo.nparams do + local k = debug.getlocal(2, i) + table.insert(func_args, k) + end + if debuginfo.isvararg then + table.insert(func_args, "...") + end + debug.sethook(oldhook) + error('aborting the call to introspected function') + end + end + oldhook = debug.sethook(hook, "c") -- invoke hook() on function call + f(unpack({})) -- huh? + end) + return "(" .. table.concat(func_args, ", ") .. ")" +end + +function table_print (tt, indent, done) + done = done or {} + indent = indent or 0 + local result = "" + -- Ordered for-iterator for tables with tostring-able keys. + local function ordered_iter(unordered_tt) + local keys = {} + for k in pairs(unordered_tt) do + table.insert(keys, k) + end + table.sort(keys, function (a, b) return tostring(a) < tostring(b) end) + local i = 0 + return function() + i = i + 1 + if keys[i] then + return keys[i], unordered_tt[keys[i]] + end + end + end + -- Convert to printable string (escape unprintable) + local function printable(value) + value = tostring(value) + local bytes = {} + for i = 1, #value do + local c = string.byte(value, i) + if c >= 0x20 and c < 0x7f then table.insert(bytes, string.char(c)) + else table.insert(bytes, '\\'..tostring(c)) + end + if i > 80 then table.insert(bytes, '...') break end + end + return table.concat(bytes) + end + if type(tt) == "table" then + for key, value in ordered_iter(tt) do + result = result .. string.rep (" ", indent) + if type (value) == "table" and not done [value] then + done [value] = true + result = result .. string.format("[%s] => {\n", printable (key)) + result = result .. table_print (value, indent + 4, done) + result = result .. string.rep (" ", indent) + result = result .. "}\n" + elseif type (value) == "function" then + result = result .. string.format("[%s] => function %s%s: %s\n", + tostring(key), tostring(key), funcsign(value), + string.sub(tostring(value), 11)) + else + result = result .. string.format("[%s] => %s\n", + tostring (key), printable(value)) + end + end + else -- not a table + local tt_str + if type(tt) == "function" then + tt_str = string.format("function%s: %s\n", funcsign(tt), + string.sub(tostring(tt), 11)) + else + tt_str = tostring(tt) + end + result = result .. tt_str .. "\n" + end + return result +end + +-- This extends the worker module to allow asynchronous execution of functions and nonblocking I/O. +-- The current implementation combines cqueues for Lua interface, and event.socket() in order to not +-- block resolver engine while waiting for I/O or timers. +-- +local has_cqueues, cqueues = pcall(require, 'cqueues') +if has_cqueues then + + -- Export the asynchronous sleep function + worker.sleep = cqueues.sleep + + -- Create metatable for workers to define the API + -- It can schedule multiple cqueues and yield execution when there's a wait for blocking I/O or timer + local asynchronous_worker_mt = { + work = function (self) + local ok, err, _, co = self.cq:step(0) + if not ok then + warn('[%s] error: %s %s', self.name or 'worker', err, debug.traceback(co)) + end + -- Reschedule timeout or create new one + local timeout = self.cq:timeout() + if timeout then + -- Throttle timeouts to avoid too frequent wakeups + if timeout == 0 then timeout = 0.00001 end + -- Convert from seconds to duration + timeout = timeout * sec + if not self.next_timeout then + self.next_timeout = event.after(timeout, self.on_step) + else + event.reschedule(self.next_timeout, timeout) + end + else -- Cancel running timeout when there is no next deadline + if self.next_timeout then + event.cancel(self.next_timeout) + self.next_timeout = nil + end + end + end, + wrap = function (self, f) + self.cq:wrap(f) + end, + loop = function (self) + self.on_step = function () self:work() end + self.event_fd = event.socket(self.cq:pollfd(), self.on_step) + end, + close = function (self) + if self.event_fd then + event.cancel(self.event_fd) + self.event_fd = nil + end + end, + } + + -- Implement the coroutine worker with cqueues + local function worker_new (name) + return setmetatable({name = name, cq = cqueues.new()}, { __index = asynchronous_worker_mt }) + end + + -- Create a default background worker + worker.bg_worker = worker_new('worker.background') + worker.bg_worker:loop() + + -- Wrap a function for asynchronous execution + function worker.coroutine (f) + worker.bg_worker:wrap(f) + end +else + -- Disable asynchronous execution + local function disabled () error('cqueues are required for asynchronous execution') end + worker.sleep = disabled + worker.map = disabled + worker.coroutine = disabled +end diff --git a/utils/watcher/main.c b/utils/watcher/main.c new file mode 100644 index 000000000..f50c09b6b --- /dev/null +++ b/utils/watcher/main.c @@ -0,0 +1,938 @@ +/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "kresconfig.h" + +#include "contrib/ccan/asprintf/asprintf.h" +#include "contrib/cleanup.h" +#include "contrib/ucw/mempool.h" +#include "engine.h" +#include "io.h" +#include "network.h" +#include "tls.h" +#include "udp_queue.h" +#include "worker.h" +#include "lib/defines.h" +#include "lib/dnssec.h" +#include "lib/dnssec/ta.h" +#include "lib/resolve.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "watcher.h" +#include "modules/sysrepo/common/sysrepo.h" +#include "sr_subscriptions.h" + +#ifdef ENABLE_CAP_NG +#include +#endif + +#include +#include +#if SYSTEMD_VERSION > 0 +#include +#endif +#include + + +/* @internal Array of ip address shorthand. */ +typedef array_t(char*) addr_array_t; + +typedef array_t(const char*) config_array_t; + +typedef struct { + int fd; + endpoint_flags_t flags; /**< .sock_type isn't meaningful here */ +} flagged_fd_t; +typedef array_t(flagged_fd_t) flagged_fd_array_t; + +struct args { + addr_array_t addrs, addrs_tls; + flagged_fd_array_t fds; + int control_fd; + int forks; + config_array_t config; + const char *rundir; + bool interactive; + bool quiet; + bool tty_binary_output; +}; + +/** + * TTY control: process input and free() the buffer. + * + * For parameters see http://docs.libuv.org/en/v1.x/stream.html#c.uv_read_cb + * + * - This is just basic read-eval-print; libedit is supported through kresc; + * - stream->data contains program arguments (struct args); + */ +static void tty_process_input(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) +{ + char *cmd = buf ? buf->base : NULL; /* To be free()d on return. */ + + /* Set output streams */ + FILE *out = stdout; + uv_os_fd_t stream_fd = 0; + struct args *args = stream->data; + if (uv_fileno((uv_handle_t *)stream, &stream_fd)) { + uv_close((uv_handle_t *)stream, (uv_close_cb) free); + free(cmd); + return; + } + if (stream_fd != STDIN_FILENO) { + if (nread < 0) { /* Close if disconnected */ + uv_close((uv_handle_t *)stream, (uv_close_cb) free); + } + if (nread <= 0) { + free(cmd); + return; + } + uv_os_fd_t dup_fd = dup(stream_fd); + if (dup_fd >= 0) { + out = fdopen(dup_fd, "w"); + } + } + + /* Execute */ + if (stream && cmd && nread > 0) { + /* Ensure cmd is 0-terminated */ + if (cmd[nread - 1] == '\n') { + cmd[nread - 1] = '\0'; + } else { + if (nread >= buf->len) { /* only equality should be possible */ + char *newbuf = realloc(cmd, nread + 1); + if (!newbuf) + goto finish; + cmd = newbuf; + } + cmd[nread] = '\0'; + } + + /* Pseudo-command for switching to "binary output"; */ + if (strcmp(cmd, "__binary") == 0) { + args->tty_binary_output = true; + goto finish; + } + + lua_State *L = the_worker->engine->L; + int ret = engine_cmd(L, cmd, false); + const char *message = ""; + if (lua_gettop(L) > 0) { + message = lua_tostring(L, -1); + } + + /* Simpler output in binary mode */ + if (args->tty_binary_output) { + size_t len_s = strlen(message); + if (len_s > UINT32_MAX) + goto finish; + uint32_t len_n = htonl(len_s); + fwrite(&len_n, sizeof(len_n), 1, out); + fwrite(message, len_s, 1, out); + lua_settop(L, 0); + goto finish; + } + + /* Log to remote socket if connected */ + const char *delim = args->quiet ? "" : "> "; + if (stream_fd != STDIN_FILENO) { + if (VERBOSE_STATUS) + fprintf(stdout, "%s\n", cmd); /* Duplicate command to logs */ + if (message) + fprintf(out, "%s", message); /* Duplicate output to sender */ + if (message || !args->quiet) + fprintf(out, "\n"); + fprintf(out, "%s", delim); + } + if (stream_fd == STDIN_FILENO || VERBOSE_STATUS) { + /* Log to standard streams */ + FILE *fp_out = ret ? stderr : stdout; + if (message) + fprintf(fp_out, "%s", message); + if (message || !args->quiet) + fprintf(fp_out, "\n"); + fprintf(fp_out, "%s", delim); + } + lua_settop(L, 0); + } +finish: + free(cmd); + /* Close if redirected */ + if (stream_fd != STDIN_FILENO) { + fclose(out); + } +} + +static void tty_alloc(uv_handle_t *handle, size_t suggested, uv_buf_t *buf) { + buf->len = suggested; + buf->base = malloc(suggested); +} + +static void tty_accept(uv_stream_t *master, int status) +{ + uv_tcp_t *client = malloc(sizeof(*client)); + struct args *args = master->data; + if (client) { + uv_tcp_init(master->loop, client); + if (uv_accept(master, (uv_stream_t *)client) != 0) { + free(client); + return; + } + client->data = args; + uv_read_start((uv_stream_t *)client, tty_alloc, tty_process_input); + /* Write command line */ + if (!args->quiet) { + uv_buf_t buf = { "> ", 2 }; + uv_try_write((uv_stream_t *)client, &buf, 1); + } + } +} + +/* @internal AF_LOCAL reads may still be interrupted, loop it. */ +static bool ipc_readall(int fd, char *dst, size_t len) +{ + while (len > 0) { + int rb = read(fd, dst, len); + if (rb > 0) { + dst += rb; + len -= rb; + } else if (errno != EAGAIN && errno != EINTR) { + return false; + } + } + return true; +} + +static void ipc_activity(uv_poll_t *handle, int status, int events) +{ + struct engine *engine = handle->data; + if (status != 0) { + kr_log_error("[system] ipc: %s\n", uv_strerror(status)); + return; + } + /* Get file descriptor from handle */ + uv_os_fd_t fd = 0; + (void) uv_fileno((uv_handle_t *)(handle), &fd); + /* Read expression from IPC pipe */ + uint32_t len = 0; + auto_free char *rbuf = NULL; + if (!ipc_readall(fd, (char *)&len, sizeof(len))) { + goto failure; + } + if (len < UINT32_MAX) { + rbuf = malloc(len + 1); + } else { + errno = EINVAL; + } + if (!rbuf) { + goto failure; + } + if (!ipc_readall(fd, rbuf, len)) { + goto failure; + } + rbuf[len] = '\0'; + /* Run expression */ + const char *message = ""; + int ret = engine_ipc(engine, rbuf); + if (ret > 0) { + message = lua_tostring(engine->L, -1); + } + /* Clear the Lua stack */ + lua_settop(engine->L, 0); + /* Send response back */ + len = strlen(message); + if (write(fd, &len, sizeof(len)) != sizeof(len) || + write(fd, message, len) != len) { + goto failure; + } + return; /* success! */ +failure: + /* Note that if the piped command got read or written partially, + * we would get out of sync and only receive rubbish now. + * Therefore we prefer to stop IPC, but we try to continue with all else. + */ + kr_log_error("[system] stopping ipc because of: %s\n", strerror(errno)); + uv_poll_stop(handle); + uv_close((uv_handle_t *)handle, (uv_close_cb)free); +} + +static bool ipc_watch(uv_loop_t *loop, struct engine *engine, int fd) +{ + uv_poll_t *poller = malloc(sizeof(*poller)); + if (!poller) { + return false; + } + int ret = uv_poll_init(loop, poller, fd); + if (ret != 0) { + free(poller); + return false; + } + poller->data = engine; + ret = uv_poll_start(poller, UV_READABLE, ipc_activity); + if (ret != 0) { + free(poller); + return false; + } + /* libuv sets O_NONBLOCK whether we want it or not */ + (void) fcntl(fd, F_SETFD, fcntl(fd, F_GETFL) & ~O_NONBLOCK); + return true; +} + +static void signal_handler(uv_signal_t *handle, int signum) +{ + uv_stop(uv_default_loop()); + uv_signal_stop(handle); +} + +/** SIGBUS -> attempt to remove the overflowing cache file and abort. */ +static void sigbus_handler(int sig, siginfo_t *siginfo, void *ptr) +{ + /* We can't safely assume that printf-like functions work, but write() is OK. + * See POSIX for the safe functions, e.g. 2017 version just above this link: + * http://pubs.opengroup.org/onlinepubs/9699919799/functions/V2_chap02.html#tag_15_04_04 + */ + #define WRITE_ERR(err_charray) \ + (void)write(STDERR_FILENO, err_charray, sizeof(err_charray)) + /* Unfortunately, void-cast on the write isn't enough to avoid the warning. */ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-result" + const char msg_typical[] = + "\nSIGBUS received; this is most likely due to filling up the filesystem where cache resides.\n", + msg_unknown[] = "\nSIGBUS received, cause unknown.\n", + msg_deleted[] = "Cache file deleted.\n", + msg_del_fail[] = "Cache file deletion failed.\n", + msg_final[] = "kresd can not recover reliably by itself, exiting.\n"; + if (siginfo->si_code != BUS_ADRERR) { + WRITE_ERR(msg_unknown); + goto end; + } + WRITE_ERR(msg_typical); + if (!kr_cache_emergency_file_to_remove) goto end; + if (unlink(kr_cache_emergency_file_to_remove)) { + WRITE_ERR(msg_del_fail); + } else { + WRITE_ERR(msg_deleted); + } +end: + WRITE_ERR(msg_final); + _exit(128 - sig); /*< regular return from OS-raised SIGBUS can't work anyway */ + #undef WRITE_ERR + #pragma GCC diagnostic pop +} + + +/* + * Server operation. + */ + +static int fork_workers(fd_array_t *ipc_set, int forks) +{ + /* Fork subprocesses if requested */ + while (--forks > 0) { + int sv[2] = {-1, -1}; + if (socketpair(AF_LOCAL, SOCK_STREAM, 0, sv) < 0) { + perror("[system] socketpair"); + return kr_error(errno); + } + int pid = fork(); + if (pid < 0) { + perror("[system] fork"); + return kr_error(errno); + } + + /* Forked process */ + if (pid == 0) { + array_clear(*ipc_set); + array_push(*ipc_set, sv[0]); + close(sv[1]); + return forks; + /* Parent process */ + } else { + array_push(*ipc_set, sv[1]); + /* Do not share parent-end with other forks. */ + (void) fcntl(sv[1], F_SETFD, FD_CLOEXEC); + close(sv[0]); + } + } + return 0; +} + +static void help(int argc, char *argv[]) +{ + printf("Usage: %s [parameters] [rundir]\n", argv[0]); + printf("\nParameters:\n" + " -c, --config=[path] Config file path (relative to [rundir]) (default: config).\n" + " -q, --quiet No command prompt in interactive mode.\n" + " -v, --verbose Run in verbose mode." +#ifdef NOVERBOSELOG + " (Recompile without -DNOVERBOSELOG to activate.)" +#endif + "\n" + " -V, --version Print version of the server.\n" + " -h, --help Print help and usage.\n" + "Options:\n" + " [rundir] Path to the working directory (default: .)\n"); +} + +/** \return exit code for main() */ +static int run_worker(uv_loop_t *loop, struct engine *engine, fd_array_t *ipc_set, bool leader, struct args *args) +{ + /* Start sysrepo context */ + sysrepo_uv_ctx_t *sysrepo = engine->watcher.sysrepo; + int ret = 0; + if (!ret) ret = sysrepo_subscr_register(sysrepo->session, &sysrepo->subscription); + if (!ret) ret = sysrepo_ctx_start(loop, sysrepo); + if (ret) { + kr_log_error("[sysrepo] failed to subscribe for changes: %s\n", sr_strerror(ret)); + } + + /* Start Processes if Enabled */ + //resolver_start(); + + /* Only some kinds of stdin work with uv_pipe_t. + * Otherwise we would abort() from libuv e.g. with interactive) switch (uv_guess_handle(0)) { + case UV_TTY: /* standard terminal */ + /* TODO: it has worked OK so far, but we'd better use uv_tty_* + * for this case instead of uv_pipe_*. */ + case UV_NAMED_PIPE: /* echo 'quit()' | kresd ... */ + break; + default: + kr_log_error( + "[system] error: standard input is not a terminal or pipe; " + "use '-f 1' if you want non-interactive mode. " + "Commands can be simply added to your configuration file or sent over the tty/$PID control socket.\n" + ); + return EXIT_FAILURE; + } + + if (setvbuf(stdout, NULL, _IONBF, 0) || setvbuf(stderr, NULL, _IONBF, 0)) { + kr_log_error("[system] failed to to set output buffering (ignored): %s\n", + strerror(errno)); + fflush(stderr); + } + + /* Control sockets or TTY */ + auto_free char *sock_file = NULL; + uv_pipe_t pipe; + uv_pipe_init(loop, &pipe, 0); + pipe.data = args; + if (args->interactive) { + if (!args->quiet) + printf("[system] interactive mode\n> "); + uv_pipe_open(&pipe, 0); + uv_read_start((uv_stream_t*) &pipe, tty_alloc, tty_process_input); + } else { + int pipe_ret = -1; + if (args->control_fd != -1) { + pipe_ret = uv_pipe_open(&pipe, args->control_fd); + } else { + (void) mkdir("tty", S_IRWXU|S_IRWXG); + sock_file = afmt("tty/%ld", (long)getpid()); + if (sock_file) { + pipe_ret = uv_pipe_bind(&pipe, sock_file); + } + } + if (!pipe_ret) + uv_listen((uv_stream_t *) &pipe, 16, tty_accept); + } + /* Watch IPC pipes (or just assign them if leading the pgroup). */ + if (!leader) { + for (size_t i = 0; i < ipc_set->len; ++i) { + if (!ipc_watch(loop, engine, ipc_set->at[i])) { + kr_log_error("[system] failed to create poller: %s\n", strerror(errno)); + close(ipc_set->at[i]); + } + } + } + memcpy(&engine->ipc_set, ipc_set, sizeof(*ipc_set)); + + /* Notify supervisor. */ +#if SYSTEMD_VERSION > 0 + sd_notify(0, "READY=1"); +#endif + /* Run event loop */ + uv_run(loop, UV_RUN_DEFAULT); + if (sock_file) { + unlink(sock_file); + } + uv_close((uv_handle_t *)&pipe, NULL); /* Seems OK even on the stopped loop. */ + return EXIT_SUCCESS; +} + +#if SYSTEMD_VERSION >= 227 +static void free_sd_socket_names(char **socket_names, int count) +{ + for (int i = 0; i < count; i++) { + free(socket_names[i]); + } + free(socket_names); +} +#endif + +static void args_init(struct args *args) +{ + memset(args, 0, sizeof(struct args)); + /* Zeroed arrays are OK. */ + args->forks = 1; + args->control_fd = -1; + args->interactive = true; + args->quiet = false; +} + +/* Free pointed-to resources. */ +static void args_deinit(struct args *args) +{ + array_clear(args->addrs); + array_clear(args->addrs_tls); + for (int i = 0; i < args->fds.len; ++i) + free_const(args->fds.at[i].flags.kind); + array_clear(args->fds); + array_clear(args->config); +} + +static long strtol_10(const char *s) +{ + if (!s) abort(); + /* ^^ This shouldn't ever happen. When getopt_long() returns an option + * character that has a mandatory parameter, optarg can't be NULL. */ + return strtol(s, NULL, 10); +} + +/** Process arguments into struct args. + * @return >=0 if main() should be exited immediately. + */ +static int parse_args(int argc, char **argv, struct args *args) +{ + /* Long options. */ + int c = 0, li = 0; + struct option opts[] = { + {"addr", required_argument, 0, 'a'}, + {"tls", required_argument, 0, 't'}, + {"fd", required_argument, 0, 'S'}, + {"config", required_argument, 0, 'c'}, + {"forks", required_argument, 0, 'f'}, + {"verbose", no_argument, 0, 'v'}, + {"quiet", no_argument, 0, 'q'}, + {"version", no_argument, 0, 'V'}, + {"help", no_argument, 0, 'h'}, + {0, 0, 0, 0} + }; + while ((c = getopt_long(argc, argv, "a:t:S:c:f:m:K:k:vqVh", opts, &li)) != -1) { + switch (c) + { + case 'a': + array_push(args->addrs, optarg); + break; + case 't': + array_push(args->addrs_tls, optarg); + break; + case 'c': + assert(optarg != NULL); + array_push(args->config, optarg); + break; + case 'f': + args->interactive = false; + args->forks = strtol_10(optarg); + if (args->forks <= 0) { + kr_log_error("[system] error '-f' requires a positive" + " number, not '%s'\n", optarg); + return EXIT_FAILURE; + } + break; + case 'v': + kr_verbose_set(true); +#ifdef NOVERBOSELOG + kr_log_info("--verbose flag has no effect due to compilation with -DNOVERBOSELOG.\n"); +#endif + break; + case 'q': + args->quiet = true; + break; + case 'V': + kr_log_info("%s, version %s\n", "Knot Resolver", PACKAGE_VERSION); + return EXIT_SUCCESS; + case 'h': + case '?': + help(argc, argv); + return EXIT_SUCCESS; + default: + help(argc, argv); + return EXIT_FAILURE; + case 'S': + (void)0; + flagged_fd_t ffd = { 0 }; + char *endptr; + ffd.fd = strtol(optarg, &endptr, 10); + if (endptr != optarg && endptr[0] == '\0') { + /* Plain DNS */ + ffd.flags.tls = false; + } else if (endptr[0] == ':' && strcasecmp(endptr + 1, "tls") == 0) { + /* DoT */ + ffd.flags.tls = true; + /* We know what .sock_type should be but it wouldn't help. */ + } else if (endptr[0] == ':' && endptr[1] != '\0') { + /* Some other kind; no checks here. */ + ffd.flags.kind = strdup(endptr + 1); + } else { + kr_log_error("[system] incorrect value passed to '-S/--fd': %s\n", + optarg); + return EXIT_FAILURE; + } + array_push(args->fds, ffd); + break; + } + } + if (optind < argc) { + args->rundir = argv[optind]; + } + return -1; +} + +/** Just convert addresses to file-descriptors; clear *addrs on success. + * @note AF_UNIX is supported (starting with '/'). + * @return zero or exit code for main() + */ +static int bind_sockets(addr_array_t *addrs, bool tls, flagged_fd_array_t *fds) +{ + bool has_error = false; + for (size_t i = 0; i < addrs->len; ++i) { + /* Get port and separate address string. */ + uint16_t port = tls ? KR_DNS_TLS_PORT : KR_DNS_PORT; + char addr_buf[INET6_ADDRSTRLEN + 1]; + int ret; + const char *addr_str; + const int family = kr_straddr_family(addrs->at[i]); + if (family == AF_UNIX) { + ret = 0; + addr_str = addrs->at[i]; + } else { /* internet socket (or garbage) */ + ret = kr_straddr_split(addrs->at[i], addr_buf, &port); + addr_str = addr_buf; + } + /* Get sockaddr. */ + struct sockaddr *sa = NULL; + if (ret == 0) { + sa = kr_straddr_socket(addr_str, port, NULL); + if (!sa) ret = kr_error(EINVAL); /* could be ENOMEM but unlikely */ + } + flagged_fd_t ffd = { .flags = { .tls = tls } }; + if (ret == 0 && !tls && family != AF_UNIX) { + /* AF_UNIX can do SOCK_DGRAM, but let's not support that *here*. */ + ffd.fd = io_bind(sa, SOCK_DGRAM, NULL); + if (ffd.fd < 0) + ret = ffd.fd; + else if (array_push(*fds, ffd) < 0) + ret = kr_error(ENOMEM); + } + if (ret == 0) { /* common for TCP and TLS, including AF_UNIX cases */ + ffd.fd = io_bind(sa, SOCK_STREAM, NULL); + if (ffd.fd < 0) + ret = ffd.fd; + else if (array_push(*fds, ffd) < 0) + ret = kr_error(ENOMEM); + } + free(sa); + if (ret != 0) { + kr_log_error("[system] bind to '%s'%s: %s\n", + addrs->at[i], tls ? " (TLS)" : "", kr_strerror(ret)); + has_error = true; + } + } + array_clear(*addrs); + return has_error ? EXIT_FAILURE : kr_ok(); +} + +static int start_listening(struct network *net, flagged_fd_array_t *fds) { + int some_bad_ret = 0; + for (size_t i = 0; i < fds->len; ++i) { + flagged_fd_t *ffd = &fds->at[i]; + int ret = network_listen_fd(net, ffd->fd, ffd->flags); + if (ret != 0) { + some_bad_ret = ret; + /* TODO: try logging address@port. It's not too important, + * because typical problems happen during binding already. + * (invalid address, permission denied) */ + kr_log_error("[system] listen on fd=%d: %s\n", + ffd->fd, kr_strerror(ret)); + /* Continue printing all of these before exiting. */ + } else { + ffd->flags.kind = NULL; /* ownership transferred */ + } + } + return some_bad_ret; +} + +/* Drop POSIX 1003.1e capabilities. */ +// static void drop_capabilities(void) +// { +// #ifdef ENABLE_CAP_NG +// /* Drop all capabilities. */ +// if (capng_have_capability(CAPNG_EFFECTIVE, CAP_SETPCAP)) { +// capng_clear(CAPNG_SELECT_BOTH); + +// /* Apply. */ +// if (capng_apply(CAPNG_SELECT_BOTH) < 0) { +// kr_log_error("[system] failed to set process capabilities: %s\n", +// strerror(errno)); +// } +// } else { +// kr_log_info("[system] process not allowed to set capabilities, skipping\n"); +// } +// #endif /* ENABLE_CAP_NG */ +// } + +int main(int argc, char **argv) +{ + struct args args; + args_init(&args); + int ret = parse_args(argc, argv, &args); + if (ret >= 0) goto cleanup_args; + + ret = bind_sockets(&args.addrs, false, &args.fds); + if (ret) goto cleanup_args; + ret = bind_sockets(&args.addrs_tls, true, &args.fds); + if (ret) goto cleanup_args; + +#if SYSTEMD_VERSION >= 227 + /* Accept passed sockets from systemd supervisor. */ + char **socket_names = NULL; + int sd_nsocks = sd_listen_fds_with_names(0, &socket_names); + if (sd_nsocks < 0) { + kr_log_error("[system] failed passing sockets from systemd: %s\n", + kr_strerror(sd_nsocks)); + free_sd_socket_names(socket_names, sd_nsocks); + ret = EXIT_FAILURE; + goto cleanup_args; + } + if (sd_nsocks > 0 && args.forks != 1) { + kr_log_error("[system] when run under systemd-style supervision, " + "use single-process only (bad: --forks=%d).\n", args.forks); + free_sd_socket_names(socket_names, sd_nsocks); + ret = EXIT_FAILURE; + goto cleanup_args; + } + for (int i = 0; i < sd_nsocks; ++i) { + /* when run under systemd supervision, do not use interactive mode */ + args.interactive = false; + flagged_fd_t ffd = { .fd = SD_LISTEN_FDS_START + i }; + + if (!strcasecmp("control", socket_names[i])) { + if (args.control_fd != -1) { + kr_log_error("[system] multiple control sockets passed from systemd\n"); + ret = EXIT_FAILURE; + break; + } + args.control_fd = ffd.fd; + free(socket_names[i]); + } else { + if (!strcasecmp("dns", socket_names[i])) { + free(socket_names[i]); + } else if (!strcasecmp("tls", socket_names[i])) { + ffd.flags.tls = true; + free(socket_names[i]); + } else { + ffd.flags.kind = socket_names[i]; + } + array_push(args.fds, ffd); + } + /* Either freed or passed ownership. */ + socket_names[i] = NULL; + } + free_sd_socket_names(socket_names, sd_nsocks); + if (ret) goto cleanup_args; +#endif + + /* Switch to rundir. */ + if (args.rundir != NULL) { + /* FIXME: access isn't a good way if we start as root and drop privileges later */ + if (access(args.rundir, W_OK) != 0 + || chdir(args.rundir) != 0) { + kr_log_error("[system] rundir '%s': %s\n", + args.rundir, strerror(errno)); + return EXIT_FAILURE; + } + } + + /* Select which config files to load and verify they are read-able. */ + bool load_defaults = true; + size_t i = 0; + while (i < args.config.len) { + const char *config = args.config.at[i]; + if (strcmp(config, "-") == 0) { + load_defaults = false; + array_del(args.config, i); + continue; /* don't increment i */ + } else if (access(config, R_OK) != 0) { + char cwd[PATH_MAX]; + get_workdir(cwd, sizeof(cwd)); + kr_log_error("[system] config '%s' (workdir '%s'): %s\n", + config, cwd, strerror(errno)); + return EXIT_FAILURE; + } + i++; + } + if (args.config.len == 0 && access("config", R_OK) == 0) + array_push(args.config, "config"); + if (load_defaults) + array_push(args.config, LIBDIR "/config-watcher.lua"); + + /* File-descriptor count limit: soft->hard. */ + struct rlimit rlim; + ret = getrlimit(RLIMIT_NOFILE, &rlim); + if (ret == 0 && rlim.rlim_cur != rlim.rlim_max) { + kr_log_verbose("[system] increasing file-descriptor limit: %ld -> %ld\n", + (long)rlim.rlim_cur, (long)rlim.rlim_max); + rlim.rlim_cur = rlim.rlim_max; + ret = setrlimit(RLIMIT_NOFILE, &rlim); + } + if (ret) { + kr_log_error("[system] failed to get or set file-descriptor limit: %s\n", + strerror(errno)); + } else if (rlim.rlim_cur < 512*1024) { + kr_log_info("[system] warning: hard limit for number of file-descriptors is only %ld but recommended value is 524288\n", + rlim.rlim_cur); + } + + /* Connect forks with local socket */ + fd_array_t ipc_set; + array_init(ipc_set); + /* Fork subprocesses if requested */ + int fork_id = fork_workers(&ipc_set, args.forks); + if (fork_id < 0) { + return EXIT_FAILURE; + } + + //kr_crypto_init(); + + /* Create a server engine. */ + knot_mm_t pool = { + .ctx = mp_new (4096), + .alloc = (knot_mm_alloc_t) mp_alloc + }; + /** Static to work around lua_pushlightuserdata() limitations. + * TODO: convert to a proper singleton like worker, most likely. */ + static struct engine engine; + ret = engine_init(&engine, &pool); + if (ret != 0) { + kr_log_error("[system] failed to initialize engine: %s\n", kr_strerror(ret)); + return EXIT_FAILURE; + } + /* Initialize the worker */ + ret = worker_init(&engine, fork_id, args.forks); + if (ret != 0) { + kr_log_error("[system] failed to initialize worker: %s\n", kr_strerror(ret)); + return EXIT_FAILURE; + } + + uv_loop_t *loop = uv_default_loop(); + /* Catch some signals. */ + uv_signal_t sigint, sigterm; + if (true) ret = uv_signal_init(loop, &sigint); + if (!ret) ret = uv_signal_init(loop, &sigterm); + if (!ret) ret = uv_signal_start(&sigint, signal_handler, SIGINT); + if (!ret) ret = uv_signal_start(&sigterm, signal_handler, SIGTERM); + /* Block SIGPIPE; see https://github.com/libuv/libuv/issues/45 */ + if (!ret && signal(SIGPIPE, SIG_IGN) == SIG_ERR) ret = errno; + if (!ret) { + /* Catching SIGBUS via uv_signal_* can't work; see: + * https://github.com/libuv/libuv/pull/1987 */ + struct sigaction sa; + memset(&sa, 0, sizeof(sa)); + sa.sa_sigaction = sigbus_handler; + sa.sa_flags = SA_SIGINFO; + if (sigaction(SIGBUS, &sa, NULL)) { + ret = errno; + } + } + if (ret) { + kr_log_error("[system] failed to set up signal handlers: %s\n", + strerror(abs(errno))); + ret = EXIT_FAILURE; + goto cleanup; + } + /* Profiling: avoid SIGPROF waking up the event loop. Otherwise the profiles + * (of the usual type) may skew results, e.g. epoll_pwait() taking lots of time. */ + ret = uv_loop_configure(loop, UV_LOOP_BLOCK_SIGNAL, SIGPROF); + if (ret) { + kr_log_info("[system] failed to block SIGPROF in event loop, ignoring: %s\n", + uv_strerror(ret)); + } + + /* Start listening, in the sense of network_listen_fd(). */ + // if (start_listening(&engine.net, &args.fds) != 0) { + // ret = EXIT_FAILURE; + // goto cleanup; + // } + + // ret = udp_queue_init_global(loop); + // if (ret) { + // kr_log_error("[system] failed to initialize UDP queue: %s\n", + // kr_strerror(ret)); + // ret = EXIT_FAILURE; + // goto cleanup; + // } + + /* Start the scripting engine */ + if (engine_load_sandbox(&engine) != 0) { + ret = EXIT_FAILURE; + goto cleanup; + } + + for (i = 0; i < args.config.len; ++i) { + const char *config = args.config.at[i]; + if (engine_loadconf(&engine, config) != 0) { + ret = EXIT_FAILURE; + goto cleanup; + } + lua_settop(engine.L, 0); + } + + //drop_capabilities(); + + if (engine_start(&engine) != 0) { + ret = EXIT_FAILURE; + goto cleanup; + } + + // if (network_engage_endpoints(&engine.net)) { + // ret = EXIT_FAILURE; + // goto cleanup; + // } + + /* Run the event loop */ + ret = run_worker(loop, &engine, &ipc_set, fork_id == 0, &args); + +cleanup:/* Cleanup. */ + engine_deinit(&engine); + worker_deinit(); + if (loop != NULL) { + uv_loop_close(loop); + } + mp_delete(pool.ctx); +cleanup_args: + args_deinit(&args); + //kr_crypto_cleanup(); + return ret; +} diff --git a/utils/watcher/meson.build b/utils/watcher/meson.build new file mode 100644 index 000000000..337f061e5 --- /dev/null +++ b/utils/watcher/meson.build @@ -0,0 +1,56 @@ +# kres-watcher + +watcher_src = files([ + 'bindings/cache.c', + 'bindings/event.c', + 'bindings/impl.c', + 'bindings/modules.c', + 'bindings/net.c', + 'bindings/worker.c', + 'engine.c', + 'ffimodule.c', + 'io.c', + 'main.c', + 'network.c', + 'session.c', + 'tls.c', + 'tls_ephemeral_credentials.c', + 'tls_session_ticket-srv.c', + 'udp_queue.c', + 'watcher.c', + 'worker.c', + 'zimport.c', + 'sr_subscriptions.c', + 'dbus_control.c' +]) +c_src_lint += watcher_src + +watcher_deps = [ + contrib_dep, + kresconfig_dep, + libkres_dep, + libknot, + libzscanner, + libdnssec, + libuv, + luajit, + gnutls, + libsystemd, + capng, + libyang, + libsysrepo, +] + +subdir('lua') + +if build_sysrepo + kres_watcher = executable( + 'kres-watcher', + watcher_src, + sysrepo_common_src, + dependencies: watcher_deps, + export_dynamic: true, + install: true, + install_dir: get_option('sbindir'), + ) +endif diff --git a/utils/watcher/network.c b/utils/watcher/network.c new file mode 100644 index 000000000..5fb2f6604 --- /dev/null +++ b/utils/watcher/network.c @@ -0,0 +1,552 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "network.h" + +#include "bindings/impl.h" +#include "io.h" +#include "tls.h" +#include "worker.h" + +#include +#include +#include + +void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog) +{ + if (net != NULL) { + net->loop = loop; + net->endpoints = map_make(NULL); + net->endpoint_kinds = trie_create(NULL); + net->tls_client_params = NULL; + net->tls_session_ticket_ctx = /* unsync. random, by default */ + tls_session_ticket_ctx_create(loop, NULL, 0); + net->tcp.in_idle_timeout = 10000; + net->tcp.tls_handshake_timeout = TLS_MAX_HANDSHAKE_TIME; + net->tcp_backlog = tcp_backlog; + } +} + +/** Notify the registered function about endpoint getting open. + * If log_port < 1, don't log it. */ +static int endpoint_open_lua_cb(struct network *net, struct endpoint *ep, + const char *log_addr) +{ + const bool ok = ep->flags.kind && !ep->handle && !ep->engaged && ep->fd != -1; + if (!ok) { + assert(!EINVAL); + return kr_error(EINVAL); + } + /* First find callback in the endpoint registry. */ + lua_State *L = the_worker->engine->L; + void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind, + strlen(ep->flags.kind)); + if (!pp && net->missing_kind_is_error) { + kr_log_error("warning: network socket kind '%s' not handled when opening '%s", + ep->flags.kind, log_addr); + if (ep->family != AF_UNIX) + kr_log_error("#%d", ep->port); + kr_log_error("'. Likely causes: typo or not loading 'http' module.\n"); + /* No hard error, for now. LATER: perhaps differentiate between + * explicit net.listen() calls and "just unused" systemd sockets. + return kr_error(ENOENT); + */ + } + if (!pp) return kr_ok(); + + /* Now execute the callback. */ + const int fun_id = (char *)*pp - (char *)NULL; + lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id); + lua_pushboolean(L, true /* open */); + lua_pushpointer(L, ep); + if (ep->family == AF_UNIX) { + lua_pushstring(L, log_addr); + } else { + lua_pushfstring(L, "%s#%d", log_addr, ep->port); + } + if (lua_pcall(L, 3, 0, 0)) { + kr_log_error("error opening %s: %s\n", log_addr, lua_tostring(L, -1)); + return kr_error(ENOSYS); /* TODO: better value? */ + } + ep->engaged = true; + return kr_ok(); +} + +static int engage_endpoint_array(const char *key, void *endpoints, void *net) +{ + endpoint_array_t *eps = (endpoint_array_t *)endpoints; + for (int i = 0; i < eps->len; ++i) { + struct endpoint *ep = &eps->at[i]; + const bool match = !ep->engaged && ep->flags.kind; + if (!match) continue; + int ret = endpoint_open_lua_cb(net, ep, key); + if (ret) return ret; + } + return 0; +} +int network_engage_endpoints(struct network *net) +{ + if (net->missing_kind_is_error) + return kr_ok(); /* maybe weird, but let's make it idempotent */ + net->missing_kind_is_error = true; + int ret = map_walk(&net->endpoints, engage_endpoint_array, net); + if (ret) { + net->missing_kind_is_error = false; /* avoid the same errors when closing */ + return ret; + } + return kr_ok(); +} + + +/** Notify the registered function about endpoint about to be closed. */ +static void endpoint_close_lua_cb(struct network *net, struct endpoint *ep) +{ + lua_State *L = the_worker->engine->L; + void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind, + strlen(ep->flags.kind)); + if (!pp && net->missing_kind_is_error) { + kr_log_error("internal error: missing kind '%s' in endpoint registry\n", + ep->flags.kind); + return; + } + if (!pp) return; + + const int fun_id = (char *)*pp - (char *)NULL; + lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id); + lua_pushboolean(L, false /* close */); + lua_pushpointer(L, ep); + lua_pushstring(L, "FIXME:endpoint-identifier"); + if (lua_pcall(L, 3, 0, 0)) { + kr_log_error("failed to close FIXME:endpoint-identifier: %s\n", + lua_tostring(L, -1)); + } +} + +static void endpoint_close(struct network *net, struct endpoint *ep, bool force) +{ + assert(!ep->handle != !ep->flags.kind); + if (ep->family == AF_UNIX) { /* The FS name would be left behind. */ + /* Extract local address for this socket. */ + struct sockaddr_un sa; + sa.sun_path[0] = '\0'; /*< probably only for lint:scan-build */ + socklen_t addr_len = sizeof(sa); + if (getsockname(ep->fd, (struct sockaddr *)&sa, &addr_len) + || unlink(sa.sun_path)) { + kr_log_error("error (ignored) when closing unix socket (fd = %d): %s\n", + ep->fd, strerror(errno)); + return; + } + } + + if (ep->flags.kind) { /* Special endpoint. */ + if (ep->engaged) { + endpoint_close_lua_cb(net, ep); + } + if (ep->fd > 0) { + close(ep->fd); /* nothing to do with errors */ + } + free_const(ep->flags.kind); + return; + } + + if (force) { /* Force close if event loop isn't running. */ + if (ep->fd >= 0) { + close(ep->fd); + } + if (ep->handle) { + ep->handle->loop = NULL; + io_free(ep->handle); + } + } else { /* Asynchronous close */ + uv_close(ep->handle, io_free); + } +} + +/** Endpoint visitor (see @file map.h) */ +static int close_key(const char *key, void *val, void *net) +{ + endpoint_array_t *ep_array = val; + for (int i = 0; i < ep_array->len; ++i) { + endpoint_close(net, &ep_array->at[i], true); + } + return 0; +} + +static int free_key(const char *key, void *val, void *ext) +{ + endpoint_array_t *ep_array = val; + array_clear(*ep_array); + free(ep_array); + return kr_ok(); +} + +int kind_unregister(trie_val_t *tv, void *L) +{ + int fun_id = (char *)*tv - (char *)NULL; + luaL_unref(L, LUA_REGISTRYINDEX, fun_id); + return 0; +} + +void network_close_force(struct network *net) +{ + if (net != NULL) { + map_walk(&net->endpoints, close_key, net); + map_walk(&net->endpoints, free_key, 0); + map_clear(&net->endpoints); + } +} + +void network_deinit(struct network *net) +{ + if (net != NULL) { + network_close_force(net); + trie_apply(net->endpoint_kinds, kind_unregister, the_worker->engine->L); + trie_free(net->endpoint_kinds); + + tls_credentials_free(net->tls_credentials); + tls_client_params_free(net->tls_client_params); + tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx); + #ifndef NDEBUG + memset(net, 0, sizeof(*net)); + #endif + } +} + +/** Fetch or create endpoint array and insert endpoint (shallow memcpy). */ +static int insert_endpoint(struct network *net, const char *addr, struct endpoint *ep) +{ + /* Fetch or insert address into map */ + endpoint_array_t *ep_array = map_get(&net->endpoints, addr); + if (ep_array == NULL) { + ep_array = malloc(sizeof(*ep_array)); + if (ep_array == NULL) { + return kr_error(ENOMEM); + } + if (map_set(&net->endpoints, addr, ep_array) != 0) { + free(ep_array); + return kr_error(ENOMEM); + } + array_init(*ep_array); + } + + if (array_reserve(*ep_array, ep_array->len + 1)) { + return kr_error(ENOMEM); + } + memcpy(&ep_array->at[ep_array->len++], ep, sizeof(*ep)); + return kr_ok(); +} + +/** Open endpoint protocols. ep->flags were pre-set. */ +static int open_endpoint(struct network *net, struct endpoint *ep, + const struct sockaddr *sa, const char *log_addr) +{ + if ((sa != NULL) == (ep->fd != -1)) { + assert(!EINVAL); + return kr_error(EINVAL); + } + if (ep->handle) { + return kr_error(EEXIST); + } + + if (sa) { + ep->fd = io_bind(sa, ep->flags.sock_type, &ep->flags); + if (ep->fd < 0) return ep->fd; + } + if (ep->flags.kind) { + /* This EP isn't to be managed internally after binding. */ + return endpoint_open_lua_cb(net, ep, log_addr); + } else { + ep->engaged = true; + /* .engaged seems not really meaningful with .kind == NULL, but... */ + } + + if (ep->family == AF_UNIX) { + /* Some parts of connection handling would need more work, + * so let's support AF_UNIX only with .kind != NULL for now. */ + kr_log_error("[system] AF_UNIX only supported with set { kind = '...' }\n"); + return kr_error(EAFNOSUPPORT); + /* + uv_pipe_t *ep_handle = malloc(sizeof(uv_pipe_t)); + */ + } + + if (ep->flags.sock_type == SOCK_DGRAM) { + if (ep->flags.tls) { + assert(!EINVAL); + return kr_error(EINVAL); + } + uv_udp_t *ep_handle = malloc(sizeof(uv_udp_t)); + ep->handle = (uv_handle_t *)ep_handle; + if (!ep->handle) { + return kr_error(ENOMEM); + } + return io_listen_udp(net->loop, ep_handle, ep->fd); + } /* else */ + + if (ep->flags.sock_type == SOCK_STREAM) { + uv_tcp_t *ep_handle = malloc(sizeof(uv_tcp_t)); + ep->handle = (uv_handle_t *)ep_handle; + if (!ep->handle) { + return kr_error(ENOMEM); + } + return io_listen_tcp(net->loop, ep_handle, ep->fd, + net->tcp_backlog, ep->flags.tls); + } /* else */ + + assert(!EINVAL); + return kr_error(EINVAL); +} + +/** @internal Fetch a pointer to endpoint of given parameters (or NULL). + * Beware that there might be multiple matches, though that's not common. */ +static struct endpoint * endpoint_get(struct network *net, const char *addr, + uint16_t port, endpoint_flags_t flags) +{ + endpoint_array_t *ep_array = map_get(&net->endpoints, addr); + if (!ep_array) { + return NULL; + } + for (int i = 0; i < ep_array->len; ++i) { + struct endpoint *ep = &ep_array->at[i]; + if (ep->port == port && endpoint_flags_eq(ep->flags, flags)) { + return ep; + } + } + return NULL; +} + +/** \note pass either sa != NULL xor ep.fd != -1; + * \note ownership of ep.flags.* is taken on success. */ +static int create_endpoint(struct network *net, const char *addr_str, + struct endpoint *ep, const struct sockaddr *sa) +{ + int ret = open_endpoint(net, ep, sa, addr_str); + if (ret == 0) { + ret = insert_endpoint(net, addr_str, ep); + } + if (ret != 0 && ep->handle) { + endpoint_close(net, ep, false); + } + return ret; +} + +int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags) +{ + /* Extract fd's socket type. */ + socklen_t len = sizeof(flags.sock_type); + int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &len); + if (ret != 0) { + return kr_error(errno); + } + if (flags.sock_type == SOCK_DGRAM && !flags.kind && flags.tls) { + assert(!EINVAL); /* Perhaps DTLS some day. */ + return kr_error(EINVAL); + } + if (flags.sock_type != SOCK_DGRAM && flags.sock_type != SOCK_STREAM) { + return kr_error(EBADF); + } + + /* Extract local address for this socket. */ + struct sockaddr_storage ss = { .ss_family = AF_UNSPEC }; + socklen_t addr_len = sizeof(ss); + ret = getsockname(fd, (struct sockaddr *)&ss, &addr_len); + if (ret != 0) { + return kr_error(errno); + } + + struct endpoint ep = { + .flags = flags, + .family = ss.ss_family, + .fd = fd, + }; + /* Extract address string and port. */ + char addr_buf[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */ + const char *addr_str; + switch (ep.family) { + case AF_INET: + ret = uv_ip4_name((const struct sockaddr_in*)&ss, addr_buf, sizeof(addr_buf)); + addr_str = addr_buf; + ep.port = ntohs(((struct sockaddr_in *)&ss)->sin_port); + break; + case AF_INET6: + ret = uv_ip6_name((const struct sockaddr_in6*)&ss, addr_buf, sizeof(addr_buf)); + addr_str = addr_buf; + ep.port = ntohs(((struct sockaddr_in6 *)&ss)->sin6_port); + break; + case AF_UNIX: + /* No SOCK_DGRAM with AF_UNIX support, at least for now. */ + ret = flags.sock_type == SOCK_STREAM ? kr_ok() : kr_error(EAFNOSUPPORT); + addr_str = ((struct sockaddr_un *)&ss)->sun_path; + break; + default: + ret = kr_error(EAFNOSUPPORT); + } + if (ret) return ret; + + /* always create endpoint for supervisor supplied fd + * even if addr+port is not unique */ + return create_endpoint(net, addr_str, &ep, NULL); +} + +int network_listen(struct network *net, const char *addr, uint16_t port, + endpoint_flags_t flags) +{ + if (net == NULL || addr == 0 || port == 0) { + assert(!EINVAL); + return kr_error(EINVAL); + } + if (endpoint_get(net, addr, port, flags)) { + return kr_error(EADDRINUSE); /* Already listening */ + } + + /* Parse address. */ + const struct sockaddr *sa = kr_straddr_socket(addr, port, NULL); + if (!sa) { + return kr_error(EINVAL); + } + struct endpoint ep = { + .flags = flags, + .fd = -1, + .port = port, + .family = sa->sa_family, + }; + int ret = create_endpoint(net, addr, &ep, sa); + free_const(sa); + return ret; +} + +int network_close(struct network *net, const char *addr, int port) +{ + endpoint_array_t *ep_array = map_get(&net->endpoints, addr); + if (!ep_array) { + return kr_error(ENOENT); + } + + size_t i = 0; + bool matched = false; /*< at least one match */ + while (i < ep_array->len) { + struct endpoint *ep = &ep_array->at[i]; + if (port < 0 || ep->port == port) { + endpoint_close(net, ep, false); + array_del(*ep_array, i); + matched = true; + /* do not advance i */ + } else { + ++i; + } + } + if (!matched) { + return kr_error(ENOENT); + } + + /* Collapse key if it has no endpoint. */ + if (ep_array->len == 0) { + array_clear(*ep_array); + free(ep_array); + map_del(&net->endpoints, addr); + } + + return kr_ok(); +} + +void network_new_hostname(struct network *net, struct engine *engine) +{ + if (net->tls_credentials && + net->tls_credentials->ephemeral_servicename) { + struct tls_credentials *newcreds; + newcreds = tls_get_ephemeral_credentials(engine); + if (newcreds) { + tls_credentials_release(net->tls_credentials); + net->tls_credentials = newcreds; + kr_log_info("[tls] Updated ephemeral X.509 cert with new hostname\n"); + } else { + kr_log_error("[tls] Failed to update ephemeral X.509 cert with new hostname, using existing one\n"); + } + } +} + +#ifdef SO_ATTACH_BPF +static int set_bpf_cb(const char *key, void *val, void *ext) +{ + endpoint_array_t *endpoints = (endpoint_array_t *)val; + assert(endpoints != NULL); + int *bpffd = (int *)ext; + assert(bpffd != NULL); + + for (size_t i = 0; i < endpoints->len; i++) { + struct endpoint *endpoint = &endpoints->at[i]; + uv_os_fd_t sockfd = -1; + if (endpoint->handle != NULL) + uv_fileno(endpoint->handle, &sockfd); + assert(sockfd != -1); + + if (setsockopt(sockfd, SOL_SOCKET, SO_ATTACH_BPF, bpffd, sizeof(int)) != 0) { + return 1; /* return error (and stop iterating over net->endpoints) */ + } + } + return 0; /* OK */ +} +#endif + +int network_set_bpf(struct network *net, int bpf_fd) +{ +#ifdef SO_ATTACH_BPF + if (map_walk(&net->endpoints, set_bpf_cb, &bpf_fd) != 0) { + /* set_bpf_cb() has returned error. */ + network_clear_bpf(net); + return 0; + } +#else + kr_log_error("[network] SO_ATTACH_BPF socket option doesn't supported\n"); + (void)net; + (void)bpf_fd; + return 0; +#endif + return 1; +} + +#ifdef SO_DETACH_BPF +static int clear_bpf_cb(const char *key, void *val, void *ext) +{ + endpoint_array_t *endpoints = (endpoint_array_t *)val; + assert(endpoints != NULL); + + for (size_t i = 0; i < endpoints->len; i++) { + struct endpoint *endpoint = &endpoints->at[i]; + uv_os_fd_t sockfd = -1; + if (endpoint->handle != NULL) + uv_fileno(endpoint->handle, &sockfd); + assert(sockfd != -1); + + if (setsockopt(sockfd, SOL_SOCKET, SO_DETACH_BPF, NULL, 0) != 0) { + kr_log_error("[network] failed to clear SO_DETACH_BPF socket option\n"); + } + /* Proceed even if setsockopt() failed, + * as we want to process all opened sockets. */ + } + return 0; +} +#endif + +void network_clear_bpf(struct network *net) +{ +#ifdef SO_DETACH_BPF + map_walk(&net->endpoints, clear_bpf_cb, NULL); +#else + kr_log_error("[network] SO_DETACH_BPF socket option doesn't supported\n"); + (void)net; +#endif +} diff --git a/utils/watcher/network.h b/utils/watcher/network.h new file mode 100644 index 000000000..cee42c7ba --- /dev/null +++ b/utils/watcher/network.h @@ -0,0 +1,129 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include "tls.h" + +#include "lib/generic/array.h" +#include "lib/generic/map.h" +#include "lib/generic/trie.h" + +#include +#include + + +struct engine; + +/** Ways to listen on a socket. */ +typedef struct { + int sock_type; /**< SOCK_DGRAM or SOCK_STREAM */ + bool tls; /**< only used together with .kind == NULL and .tcp */ + const char *kind; /**< tag for other types than the three usual */ + bool freebind; /**< used for binding to non-local address **/ +} endpoint_flags_t; + +static inline bool endpoint_flags_eq(endpoint_flags_t f1, endpoint_flags_t f2) +{ + if (f1.sock_type != f2.sock_type) + return false; + if (f1.kind && f2.kind) + return strcasecmp(f1.kind, f2.kind); + else + return f1.tls == f2.tls && f1.kind == f2.kind; +} + +/** Wrapper for a single socket to listen on. + * There are two types: normal have handle, special have flags.kind (and never both). + * + * LATER: .family might be unexpected for IPv4-in-IPv6 addresses. + * ATM AF_UNIX is only supported with flags.kind != NULL + */ +struct endpoint { + uv_handle_t *handle; /**< uv_udp_t or uv_tcp_t; NULL in case flags.kind != NULL */ + int fd; /**< POSIX file-descriptor; always used. */ + int family; /**< AF_INET or AF_INET6 or AF_UNIX */ + uint16_t port; /**< TCP/UDP port. Meaningless with AF_UNIX. */ + bool engaged; /**< to some module or internally */ + endpoint_flags_t flags; +}; + +/** @cond internal Array of endpoints */ +typedef array_t(struct endpoint) endpoint_array_t; +/* @endcond */ + +struct net_tcp_param { + uint64_t in_idle_timeout; + uint64_t tls_handshake_timeout; +}; + +struct network { + uv_loop_t *loop; + + /** Map: address string -> endpoint_array_t. + * \note even same address-port-flags tuples may appear. + * TODO: trie_t, keyed on *binary* address-port pair. */ + map_t endpoints; + + /** Registry of callbacks for special endpoint kinds (for opening/closing). + * Map: kind (lowercased) -> lua function ID converted to void * + * The ID is the usual: raw int index in the LUA_REGISTRYINDEX table. */ + trie_t *endpoint_kinds; + /** See network_engage_endpoints() */ + bool missing_kind_is_error; + + struct tls_credentials *tls_credentials; + tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */ + struct tls_session_ticket_ctx *tls_session_ticket_ctx; + struct net_tcp_param tcp; + int tcp_backlog; +}; + +void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog); +void network_deinit(struct network *net); + +/** Start listenting on addr#port with flags. + * \note if we did listen on that combination already, + * nothing is done and kr_error(EADDRINUSE) is returned. + * \note there's no short-hand to listen both on UDP and TCP. + * \note ownership of flags.* is taken on success. TODO: non-success? + */ +int network_listen(struct network *net, const char *addr, uint16_t port, + endpoint_flags_t flags); + +/** Start listenting on an open file-descriptor. + * \note flags.sock_type isn't meaningful here. + * \note ownership of flags.* is taken on success. TODO: non-success? + */ +int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags); + +/** Stop listening on all endpoints with matching addr#port. + * port < 0 serves as a wild-card. + * \return kr_error(ENOENT) if nothing matched. */ +int network_close(struct network *net, const char *addr, int port); + +/** Close all endpoints immediately (no waiting for UV loop). */ +void network_close_force(struct network *net); + +/** Enforce that all endpoints are registered from now on. + * This only does anything with struct endpoint::flags.kind != NULL. */ +int network_engage_endpoints(struct network *net); + +int network_set_tls_cert(struct network *net, const char *cert); +int network_set_tls_key(struct network *net, const char *key); +void network_new_hostname(struct network *net, struct engine *engine); +int network_set_bpf(struct network *net, int bpf_fd); +void network_clear_bpf(struct network *net); diff --git a/utils/watcher/session.c b/utils/watcher/session.c new file mode 100644 index 000000000..2ffc25980 --- /dev/null +++ b/utils/watcher/session.c @@ -0,0 +1,776 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include + +#include + +#include "lib/defines.h" +#include "session.h" +#include "engine.h" +#include "tls.h" +#include "worker.h" +#include "io.h" +#include "lib/generic/queue.h" + +#define TLS_CHUNK_SIZE (16 * 1024) + +/* Per-socket (TCP or UDP) persistent structure. + * + * In particular note that for UDP clients it's just one session (per socket) + * shared for all clients. For TCP/TLS it's for the connection-specific socket, + * i.e one session per connection. + */ +struct session { + struct session_flags sflags; /**< miscellaneous flags. */ + union inaddr peer; /**< address of peer; not for UDP clients (downstream) */ + union inaddr sockname; /**< our local address; for UDP it may be a wildcard */ + uv_handle_t *handle; /**< libuv handle for IO operations. */ + uv_timer_t timeout; /**< libuv handle for timer. */ + + struct tls_ctx_t *tls_ctx; /**< server side tls-related data. */ + struct tls_client_ctx_t *tls_client_ctx; /**< client side tls-related data. */ + + trie_t *tasks; /**< list of tasks assotiated with given session. */ + queue_t(struct qr_task *) waiting; /**< list of tasks waiting for sending to upstream. */ + + uint8_t *wire_buf; /**< Buffer for DNS message. */ + ssize_t wire_buf_size; /**< Buffer size. */ + ssize_t wire_buf_start_idx; /**< Data start offset in wire_buf. */ + ssize_t wire_buf_end_idx; /**< Data end offset in wire_buf. */ + uint64_t last_activity; /**< Time of last IO activity (if any occurs). + * Otherwise session creation time. */ +}; + +static void on_session_close(uv_handle_t *handle) +{ + struct session *session = handle->data; + assert(session->handle == handle); (void)session; + io_free(handle); +} + +static void on_session_timer_close(uv_handle_t *timer) +{ + struct session *session = timer->data; + uv_handle_t *handle = session->handle; + assert(handle && handle->data == session); + assert (session->sflags.outgoing || handle->type == UV_TCP); + if (!uv_is_closing(handle)) { + uv_close(handle, on_session_close); + } +} + +void session_free(struct session *session) +{ + if (session) { + assert(session_is_empty(session)); + session_clear(session); + free(session); + } +} + +void session_clear(struct session *session) +{ + assert(session_is_empty(session)); + if (session->handle && session->handle->type == UV_TCP) { + free(session->wire_buf); + } + trie_clear(session->tasks); + trie_free(session->tasks); + queue_deinit(session->waiting); + tls_free(session->tls_ctx); + tls_client_ctx_free(session->tls_client_ctx); + memset(session, 0, sizeof(*session)); +} + +void session_close(struct session *session) +{ + assert(session_is_empty(session)); + if (session->sflags.closing) { + return; + } + + uv_handle_t *handle = session->handle; + io_stop_read(handle); + session->sflags.closing = true; + + if (!uv_is_closing((uv_handle_t *)&session->timeout)) { + uv_timer_stop(&session->timeout); + if (session->tls_client_ctx) { + tls_close(&session->tls_client_ctx->c); + } + if (session->tls_ctx) { + tls_close(&session->tls_ctx->c); + } + + session->timeout.data = session; + uv_close((uv_handle_t *)&session->timeout, on_session_timer_close); + } +} + +int session_start_read(struct session *session) +{ + return io_start_read(session->handle); +} + +int session_stop_read(struct session *session) +{ + return io_stop_read(session->handle); +} + +int session_waitinglist_push(struct session *session, struct qr_task *task) +{ + queue_push(session->waiting, task); + worker_task_ref(task); + return kr_ok(); +} + +struct qr_task *session_waitinglist_get(const struct session *session) +{ + return (queue_len(session->waiting) > 0) ? (queue_head(session->waiting)) : NULL; +} + +struct qr_task *session_waitinglist_pop(struct session *session, bool deref) +{ + struct qr_task *t = session_waitinglist_get(session); + queue_pop(session->waiting); + if (deref) { + worker_task_unref(t); + } + return t; +} + +int session_tasklist_add(struct session *session, struct qr_task *task) +{ + trie_t *t = session->tasks; + uint16_t task_msg_id = 0; + const char *key = NULL; + size_t key_len = 0; + if (session->sflags.outgoing) { + knot_pkt_t *pktbuf = worker_task_get_pktbuf(task); + task_msg_id = knot_wire_get_id(pktbuf->wire); + key = (const char *)&task_msg_id; + key_len = sizeof(task_msg_id); + } else { + key = (const char *)&task; + key_len = sizeof(char *); + } + trie_val_t *v = trie_get_ins(t, key, key_len); + if (unlikely(!v)) { + assert(false); + return kr_error(ENOMEM); + } + if (*v == NULL) { + *v = task; + worker_task_ref(task); + } else if (*v != task) { + assert(false); + return kr_error(EINVAL); + } + return kr_ok(); +} + +int session_tasklist_del(struct session *session, struct qr_task *task) +{ + trie_t *t = session->tasks; + uint16_t task_msg_id = 0; + const char *key = NULL; + size_t key_len = 0; + trie_val_t val; + if (session->sflags.outgoing) { + knot_pkt_t *pktbuf = worker_task_get_pktbuf(task); + task_msg_id = knot_wire_get_id(pktbuf->wire); + key = (const char *)&task_msg_id; + key_len = sizeof(task_msg_id); + } else { + key = (const char *)&task; + key_len = sizeof(char *); + } + int ret = trie_del(t, key, key_len, &val); + if (ret == kr_ok()) { + assert(val == task); + worker_task_unref(val); + } + return ret; +} + +struct qr_task *session_tasklist_get_first(struct session *session) +{ + trie_val_t *val = trie_get_first(session->tasks, NULL, NULL); + return val ? (struct qr_task *) *val : NULL; +} + +struct qr_task *session_tasklist_del_first(struct session *session, bool deref) +{ + trie_val_t val = NULL; + int res = trie_del_first(session->tasks, NULL, NULL, &val); + if (res != kr_ok()) { + val = NULL; + } else if (deref) { + worker_task_unref(val); + } + return (struct qr_task *)val; +} +struct qr_task* session_tasklist_del_msgid(const struct session *session, uint16_t msg_id) +{ + trie_t *t = session->tasks; + assert(session->sflags.outgoing); + struct qr_task *ret = NULL; + const char *key = (const char *)&msg_id; + size_t key_len = sizeof(msg_id); + trie_val_t val; + int res = trie_del(t, key, key_len, &val); + if (res == kr_ok()) { + if (worker_task_numrefs(val) > 1) { + ret = val; + } + worker_task_unref(val); + } + return ret; +} + +struct qr_task* session_tasklist_find_msgid(const struct session *session, uint16_t msg_id) +{ + trie_t *t = session->tasks; + assert(session->sflags.outgoing); + struct qr_task *ret = NULL; + trie_val_t *val = trie_get_try(t, (char *)&msg_id, sizeof(msg_id)); + if (val) { + ret = *val; + } + return ret; +} + +struct session_flags *session_flags(struct session *session) +{ + return &session->sflags; +} + +struct sockaddr *session_get_peer(struct session *session) +{ + return &session->peer.ip; +} + +struct sockaddr *session_get_sockname(struct session *session) +{ + return &session->sockname.ip; +} + +struct tls_ctx_t *session_tls_get_server_ctx(const struct session *session) +{ + return session->tls_ctx; +} + +void session_tls_set_server_ctx(struct session *session, struct tls_ctx_t *ctx) +{ + session->tls_ctx = ctx; +} + +struct tls_client_ctx_t *session_tls_get_client_ctx(const struct session *session) +{ + return session->tls_client_ctx; +} + +void session_tls_set_client_ctx(struct session *session, struct tls_client_ctx_t *ctx) +{ + session->tls_client_ctx = ctx; +} + +struct tls_common_ctx *session_tls_get_common_ctx(const struct session *session) +{ + struct tls_common_ctx *tls_ctx = session->sflags.outgoing ? &session->tls_client_ctx->c : + &session->tls_ctx->c; + return tls_ctx; +} + +uv_handle_t *session_get_handle(struct session *session) +{ + return session->handle; +} + +struct session *session_get(uv_handle_t *h) +{ + return h->data; +} + +struct session *session_new(uv_handle_t *handle, bool has_tls) +{ + if (!handle) { + return NULL; + } + struct session *session = calloc(1, sizeof(struct session)); + if (!session) { + return NULL; + } + + queue_init(session->waiting); + session->tasks = trie_create(NULL); + if (handle->type == UV_TCP) { + size_t wire_buffer_size = KNOT_WIRE_MAX_PKTSIZE; + if (has_tls) { + /* When decoding large packets, + * gnutls gives the application chunks of size 16 kb each. */ + wire_buffer_size += TLS_CHUNK_SIZE; + session->sflags.has_tls = true; + } + uint8_t *wire_buf = malloc(wire_buffer_size); + if (!wire_buf) { + free(session); + return NULL; + } + session->wire_buf = wire_buf; + session->wire_buf_size = wire_buffer_size; + } else if (handle->type == UV_UDP) { + /* We use the singleton buffer from worker for all UDP (!) + * libuv documentation doesn't really guarantee this is OK, + * but the implementation for unix systems does not hold + * the buffer (both UDP and TCP) - always makes a NON-blocking + * syscall that fills the buffer and immediately calls + * the callback, whatever the result of the operation. + * We still need to keep in mind to only touch the buffer + * in this callback... */ + assert(handle->loop->data); + struct worker_ctx *worker = handle->loop->data; + session->wire_buf = worker->wire_buf; + session->wire_buf_size = sizeof(worker->wire_buf); + } + + uv_timer_init(handle->loop, &session->timeout); + + session->handle = handle; + handle->data = session; + session->timeout.data = session; + session_touch(session); + + return session; +} + +size_t session_tasklist_get_len(const struct session *session) +{ + return trie_weight(session->tasks); +} + +size_t session_waitinglist_get_len(const struct session *session) +{ + return queue_len(session->waiting); +} + +bool session_tasklist_is_empty(const struct session *session) +{ + return session_tasklist_get_len(session) == 0; +} + +bool session_waitinglist_is_empty(const struct session *session) +{ + return session_waitinglist_get_len(session) == 0; +} + +bool session_is_empty(const struct session *session) +{ + return session_tasklist_is_empty(session) && + session_waitinglist_is_empty(session); +} + +bool session_has_tls(const struct session *session) +{ + return session->sflags.has_tls; +} + +void session_set_has_tls(struct session *session, bool has_tls) +{ + session->sflags.has_tls = has_tls; +} + +void session_waitinglist_retry(struct session *session, bool increase_timeout_cnt) +{ + while (!session_waitinglist_is_empty(session)) { + struct qr_task *task = session_waitinglist_pop(session, false); + if (increase_timeout_cnt) { + worker_task_timeout_inc(task); + } + worker_task_step(task, &session->peer.ip, NULL); + worker_task_unref(task); + } +} + +void session_waitinglist_finalize(struct session *session, int status) +{ + while (!session_waitinglist_is_empty(session)) { + struct qr_task *t = session_waitinglist_pop(session, false); + worker_task_finalize(t, status); + worker_task_unref(t); + } +} + +void session_tasklist_finalize(struct session *session, int status) +{ + while (session_tasklist_get_len(session) > 0) { + struct qr_task *t = session_tasklist_del_first(session, false); + assert(worker_task_numrefs(t) > 0); + worker_task_finalize(t, status); + worker_task_unref(t); + } +} + +int session_tasklist_finalize_expired(struct session *session) +{ + int ret = 0; + queue_t(struct qr_task *) q; + uint64_t now = kr_now(); + trie_t *t = session->tasks; + trie_it_t *it; + queue_init(q); + for (it = trie_it_begin(t); !trie_it_finished(it); trie_it_next(it)) { + trie_val_t *v = trie_it_val(it); + struct qr_task *task = (struct qr_task *)*v; + if ((now - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) { + queue_push(q, task); + worker_task_ref(task); + } + } + trie_it_free(it); + + struct qr_task *task = NULL; + uint16_t msg_id = 0; + char *key = (char *)&task; + int32_t keylen = sizeof(struct qr_task *); + if (session->sflags.outgoing) { + key = (char *)&msg_id; + keylen = sizeof(msg_id); + } + while (queue_len(q) > 0) { + task = queue_head(q); + if (session->sflags.outgoing) { + knot_pkt_t *pktbuf = worker_task_get_pktbuf(task); + msg_id = knot_wire_get_id(pktbuf->wire); + } + int res = trie_del(t, key, keylen, NULL); + if (!worker_task_finished(task)) { + /* task->pending_count must be zero, + * but there are can be followers, + * so run worker_task_subreq_finalize() to ensure retrying + * for all the followers. */ + worker_task_subreq_finalize(task); + worker_task_finalize(task, KR_STATE_FAIL); + } + if (res == KNOT_EOK) { + worker_task_unref(task); + } + queue_pop(q); + worker_task_unref(task); + ++ret; + } + + queue_deinit(q); + return ret; +} + +int session_timer_start(struct session *session, uv_timer_cb cb, + uint64_t timeout, uint64_t repeat) +{ + uv_timer_t *timer = &session->timeout; + assert(timer->data == session); + int ret = uv_timer_start(timer, cb, timeout, repeat); + if (ret != 0) { + uv_timer_stop(timer); + return kr_error(ENOMEM); + } + return 0; +} + +int session_timer_restart(struct session *session) +{ + return uv_timer_again(&session->timeout); +} + +int session_timer_stop(struct session *session) +{ + return uv_timer_stop(&session->timeout); +} + +ssize_t session_wirebuf_consume(struct session *session, const uint8_t *data, ssize_t len) +{ + if (data != &session->wire_buf[session->wire_buf_end_idx]) { + /* shouldn't happen */ + return kr_error(EINVAL); + } + + if (len < 0) { + /* shouldn't happen */ + return kr_error(EINVAL); + } + + if (session->wire_buf_end_idx + len > session->wire_buf_size) { + /* shouldn't happen */ + return kr_error(EINVAL); + } + + session->wire_buf_end_idx += len; + return len; +} + +knot_pkt_t *session_produce_packet(struct session *session, knot_mm_t *mm) +{ + session->sflags.wirebuf_error = false; + if (session->wire_buf_end_idx == 0) { + return NULL; + } + + if (session->wire_buf_start_idx == session->wire_buf_end_idx) { + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; + return NULL; + } + + if (session->wire_buf_start_idx > session->wire_buf_end_idx) { + session->sflags.wirebuf_error = true; + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; + return NULL; + } + + const uv_handle_t *handle = session->handle; + uint8_t *msg_start = &session->wire_buf[session->wire_buf_start_idx]; + ssize_t wirebuf_msg_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx; + uint16_t msg_size = 0; + + if (!handle) { + session->sflags.wirebuf_error = true; + return NULL; + } else if (handle->type == UV_TCP) { + if (wirebuf_msg_data_size < 2) { + return NULL; + } + msg_size = knot_wire_read_u16(msg_start); + if (msg_size >= session->wire_buf_size) { + session->sflags.wirebuf_error = true; + return NULL; + } + if (msg_size + 2 > wirebuf_msg_data_size) { + return NULL; + } + if (msg_size == 0) { + session->sflags.wirebuf_error = true; + return NULL; + } + msg_start += 2; + } else if (wirebuf_msg_data_size < UINT16_MAX) { + msg_size = wirebuf_msg_data_size; + } else { + session->sflags.wirebuf_error = true; + return NULL; + } + + + knot_pkt_t *pkt = knot_pkt_new(msg_start, msg_size, mm); + session->sflags.wirebuf_error = (pkt == NULL); + return pkt; +} + +int session_discard_packet(struct session *session, const knot_pkt_t *pkt) +{ + uv_handle_t *handle = session->handle; + /* Pointer to data start in wire_buf */ + uint8_t *wirebuf_data_start = &session->wire_buf[session->wire_buf_start_idx]; + /* Number of data bytes in wire_buf */ + size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx; + /* Pointer to message start in wire_buf */ + uint8_t *wirebuf_msg_start = wirebuf_data_start; + /* Number of message bytes in wire_buf. + * For UDP it is the same number as wirebuf_data_size. */ + size_t wirebuf_msg_size = wirebuf_data_size; + /* Wire data from parsed packet. */ + uint8_t *pkt_msg_start = pkt->wire; + /* Number of bytes in packet wire buffer. */ + size_t pkt_msg_size = pkt->size; + if (knot_pkt_has_tsig(pkt)) { + pkt_msg_size += pkt->tsig_wire.len; + } + + session->sflags.wirebuf_error = true; + + if (!handle) { + return kr_error(EINVAL); + } else if (handle->type == UV_TCP) { + /* wire_buf contains TCP DNS message. */ + if (wirebuf_data_size < 2) { + /* TCP message length field isn't in buffer, must not happen. */ + assert(0); + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; + return kr_error(EINVAL); + } + wirebuf_msg_size = knot_wire_read_u16(wirebuf_msg_start); + wirebuf_msg_start += 2; + if (wirebuf_msg_size + 2 > wirebuf_data_size) { + /* TCP message length field is greater then + * number of bytes in buffer, must not happen. */ + assert(0); + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; + return kr_error(EINVAL); + } + } + + if (wirebuf_msg_start != pkt_msg_start) { + /* packet wirebuf must be located at the beginning + * of the session wirebuf, must not happen. */ + assert(0); + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; + return kr_error(EINVAL); + } + + if (wirebuf_msg_size < pkt_msg_size) { + /* Message length field is lesser then packet size, + * must not happen. */ + assert(0); + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; + return kr_error(EINVAL); + } + + if (handle->type == UV_TCP) { + session->wire_buf_start_idx += wirebuf_msg_size + 2; + } else { + session->wire_buf_start_idx += pkt_msg_size; + } + session->sflags.wirebuf_error = false; + + wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx; + if (wirebuf_data_size == 0) { + session_wirebuf_discard(session); + } else if (wirebuf_data_size < KNOT_WIRE_HEADER_SIZE) { + session_wirebuf_compress(session); + } + + return kr_ok(); +} + +void session_wirebuf_discard(struct session *session) +{ + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = 0; +} + +void session_wirebuf_compress(struct session *session) +{ + if (session->wire_buf_start_idx == 0) { + return; + } + uint8_t *wirebuf_data_start = &session->wire_buf[session->wire_buf_start_idx]; + size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx; + if (session->wire_buf_start_idx < wirebuf_data_size) { + memmove(session->wire_buf, wirebuf_data_start, wirebuf_data_size); + } else { + memcpy(session->wire_buf, wirebuf_data_start, wirebuf_data_size); + } + session->wire_buf_start_idx = 0; + session->wire_buf_end_idx = wirebuf_data_size; +} + +bool session_wirebuf_error(struct session *session) +{ + return session->sflags.wirebuf_error; +} + +uint8_t *session_wirebuf_get_start(struct session *session) +{ + return session->wire_buf; +} + +size_t session_wirebuf_get_size(struct session *session) +{ + return session->wire_buf_size; +} + +uint8_t *session_wirebuf_get_free_start(struct session *session) +{ + return &session->wire_buf[session->wire_buf_end_idx]; +} + +size_t session_wirebuf_get_free_size(struct session *session) +{ + return session->wire_buf_size - session->wire_buf_end_idx; +} + +void session_poison(struct session *session) +{ + kr_asan_poison(session, sizeof(*session)); +} + +void session_unpoison(struct session *session) +{ + kr_asan_unpoison(session, sizeof(*session)); +} + +int session_wirebuf_process(struct session *session, const struct sockaddr *peer) +{ + int ret = 0; + if (session->wire_buf_start_idx == session->wire_buf_end_idx) { + return ret; + } + struct worker_ctx *worker = session_get_handle(session)->loop->data; + size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx; + uint32_t max_iterations = (wirebuf_data_size / (KNOT_WIRE_HEADER_SIZE + KNOT_WIRE_QUESTION_MIN_SIZE)) + 1; + knot_pkt_t *query = NULL; + while (((query = session_produce_packet(session, &worker->pkt_pool)) != NULL) && + (ret < max_iterations)) { + assert (!session_wirebuf_error(session)); + int res = worker_submit(session, peer, query); + if (res != kr_error(EILSEQ)) { + /* Packet has been successfully parsed. */ + ret += 1; + } + if (session_discard_packet(session, query) < 0) { + /* Packet data isn't stored in memory as expected. + something went wrong, normally should not happen. */ + break; + } + } + if (session_wirebuf_error(session)) { + ret = -1; + } + return ret; +} + +void session_kill_ioreq(struct session *s, struct qr_task *task) +{ + if (!s) { + return; + } + assert(s->sflags.outgoing && s->handle); + if (s->sflags.closing) { + return; + } + session_tasklist_del(s, task); + if (s->handle->type == UV_UDP) { + assert(session_tasklist_is_empty(s)); + session_close(s); + return; + } +} + +/** Update timestamp */ +void session_touch(struct session *s) +{ + s->last_activity = kr_now(); +} + +uint64_t session_last_activity(struct session *s) +{ + return s->last_activity; +} diff --git a/utils/watcher/session.h b/utils/watcher/session.h new file mode 100644 index 000000000..4662b53b1 --- /dev/null +++ b/utils/watcher/session.h @@ -0,0 +1,151 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include + +#include +#include + +struct qr_task; +struct worker_ctx; +struct session; + +struct session_flags { + bool outgoing : 1; /**< True: to upstream; false: from a client. */ + bool throttled : 1; /**< True: data reading from peer is temporarily stopped. */ + bool has_tls : 1; /**< True: given session uses TLS. */ + bool connected : 1; /**< True: TCP connection is established. */ + bool closing : 1; /**< True: session close sequence is in progress. */ + bool wirebuf_error : 1; /**< True: last operation with wirebuf ended up with an error. */ +}; + +/* Allocate new session for a libuv handle. + * If handle->tyoe is UV_UDP, tls parameter will be ignored. */ +struct session *session_new(uv_handle_t *handle, bool has_tls); +/* Clear and free given session. */ +void session_free(struct session *session); +/* Clear session. */ +void session_clear(struct session *session); +/** Close session. */ +void session_close(struct session *session); +/** Start reading from underlying libuv IO handle. */ +int session_start_read(struct session *session); +/** Stop reading from underlying libuv IO handle. */ +int session_stop_read(struct session *session); + +/** List of tasks been waiting for IO. */ +/** Check if list is empty. */ +bool session_waitinglist_is_empty(const struct session *session); +/** Add task to the end of the list. */ +int session_waitinglist_push(struct session *session, struct qr_task *task); +/** Get the first element. */ +struct qr_task *session_waitinglist_get(const struct session *session); +/** Get the first element and remove it from the list. */ +struct qr_task *session_waitinglist_pop(struct session *session, bool deref); +/** Get the list length. */ +size_t session_waitinglist_get_len(const struct session *session); +/** Retry resolution for each task in the list. */ +void session_waitinglist_retry(struct session *session, bool increase_timeout_cnt); +/** Finalize all tasks in the list. */ +void session_waitinglist_finalize(struct session *session, int status); + +/** List of tasks associated with session. */ +/** Check if list is empty. */ +bool session_tasklist_is_empty(const struct session *session); +/** Get the first element. */ +struct qr_task *session_tasklist_get_first(struct session *session); +/** Get the first element and remove it from the list. */ +struct qr_task *session_tasklist_del_first(struct session *session, bool deref); +/** Get the list length. */ +size_t session_tasklist_get_len(const struct session *session); +/** Add task to the list. */ +int session_tasklist_add(struct session *session, struct qr_task *task); +/** Remove task from the list. */ +int session_tasklist_del(struct session *session, struct qr_task *task); +/** Remove task with given msg_id, session_flags(session)->outgoing must be true. */ +struct qr_task* session_tasklist_del_msgid(const struct session *session, uint16_t msg_id); +/** Find task with given msg_id */ +struct qr_task* session_tasklist_find_msgid(const struct session *session, uint16_t msg_id); +/** Finalize all tasks in the list. */ +void session_tasklist_finalize(struct session *session, int status); +/** Finalize all expired tasks in the list. */ +int session_tasklist_finalize_expired(struct session *session); + +/** Both of task lists (associated & waiting). */ +/** Check if empty. */ +bool session_is_empty(const struct session *session); +/** Get pointer to session flags */ +struct session_flags *session_flags(struct session *session); +/** Get pointer to peer address. */ +struct sockaddr *session_get_peer(struct session *session); +/** Get pointer to sockname (address of our end, not meaningful for UDP downstream). */ +struct sockaddr *session_get_sockname(struct session *session); +/** Get pointer to server-side tls-related data. */ +struct tls_ctx_t *session_tls_get_server_ctx(const struct session *session); +/** Set pointer to server-side tls-related data. */ +void session_tls_set_server_ctx(struct session *session, struct tls_ctx_t *ctx); +/** Get pointer to client-side tls-related data. */ +struct tls_client_ctx_t *session_tls_get_client_ctx(const struct session *session); +/** Set pointer to client-side tls-related data. */ +void session_tls_set_client_ctx(struct session *session, struct tls_client_ctx_t *ctx); +/** Get pointer to that part of tls-related data which has common structure for + * server and client. */ +struct tls_common_ctx *session_tls_get_common_ctx(const struct session *session); + +/** Get pointer to underlying libuv handle for IO operations. */ +uv_handle_t *session_get_handle(struct session *session); +struct session *session_get(uv_handle_t *h); + +/** Start session timer. */ +int session_timer_start(struct session *session, uv_timer_cb cb, + uint64_t timeout, uint64_t repeat); +/** Restart session timer without changing it parameters. */ +int session_timer_restart(struct session *session); +/** Stop session timer. */ +int session_timer_stop(struct session *session); + +/** Get pointer to the beginning of session wirebuffer. */ +uint8_t *session_wirebuf_get_start(struct session *session); +/** Get size of session wirebuffer. */ +size_t session_wirebuf_get_size(struct session *session); +/** Get pointer to the beginning of free space in session wirebuffer. */ +uint8_t *session_wirebuf_get_free_start(struct session *session); +/** Get amount of free space in session wirebuffer. */ +size_t session_wirebuf_get_free_size(struct session *session); +/** Discard all data in session wirebuffer. */ +void session_wirebuf_discard(struct session *session); +/** Move all data to the beginning of the buffer. */ +void session_wirebuf_compress(struct session *session); +int session_wirebuf_process(struct session *session, const struct sockaddr *peer); +ssize_t session_wirebuf_consume(struct session *session, + const uint8_t *data, ssize_t len); + +/** poison session structure with ASAN. */ +void session_poison(struct session *session); +/** unpoison session structure with ASAN. */ +void session_unpoison(struct session *session); + +knot_pkt_t *session_produce_packet(struct session *session, knot_mm_t *mm); +int session_discard_packet(struct session *session, const knot_pkt_t *pkt); + +void session_kill_ioreq(struct session *s, struct qr_task *task); +/** Update timestamp */ +void session_touch(struct session *s); +/** Returns either creation time or time of last IO activity if any occurs. */ +/* Used for TCP timeout calculation. */ +uint64_t session_last_activity(struct session *s); diff --git a/utils/watcher/sr_subscriptions.c b/utils/watcher/sr_subscriptions.c new file mode 100644 index 000000000..f91e50a91 --- /dev/null +++ b/utils/watcher/sr_subscriptions.c @@ -0,0 +1,435 @@ +#include +#include +#include + +#include "kresconfig.h" + +#include "sr_subscriptions.h" +#include "dbus_control.h" +#include "watcher.h" +#include "worker.h" + +#include "modules/sysrepo/common/sysrepo.h" +#include "modules/sysrepo/common/string_helper.h" + +#define XPATH_SERVER XPATH_BASE"/server" +#define XPATH_INSTANCES XPATH_BASE"/"YM_KRES":instances" +#define XPATH_TST_SECRET XPATH_BASE"/network/tls/"YM_KRES":sticket-secret" +#define XPATH_CACHE_PREFILL XPATH_BASE"/cache/"YM_KRES":prefill" +#define XPATH_STATUS XPATH_SERVER"/"YM_KRES":status" +#define XPATH_VERSION XPATH_SERVER"/package-version" + + +static int kresd_instances_start(const char *method) +{ + char xpath[128]; + int ret = SR_ERR_OK; + sr_val_t *vals = NULL; + size_t i, val_count = 0; + sysrepo_uv_ctx_t *sysrepo = the_worker->engine->watcher.sysrepo; + int kresd_instances = the_worker->engine->watcher.config.kresd_instances; + + ret = sr_get_items(sysrepo->session, XPATH_BASE"/"YM_KRES":instances//name", 0, 0, &vals, &val_count); + + for (i = 0; i < kresd_instances; ++i) { + + char inst_name[128]; + if (i < val_count) + sprintf(&inst_name, "%s", vals[i].data.string_val); + else + sprintf(&inst_name, "%ld", i); + + kresd_ctl(method,inst_name); + } + sr_free_values(vals, val_count); +} + +static int kresd_instances_status(struct lyd_node **parent) +{ + char xpath[128]; + int ret = SR_ERR_OK; + sr_val_t *vals = NULL; + size_t i, val_count = 0; + sysrepo_uv_ctx_t *sysrepo = the_worker->engine->watcher.sysrepo; + int kresd_instances = the_worker->engine->watcher.config.kresd_instances; + + ret = sr_get_items(sysrepo->session, XPATH_BASE"/"YM_KRES":instances//name", 0, 0, &vals, &val_count); + + for (i = 0; i < kresd_instances; ++i) { + + char inst_name[128]; + char *status; + + if (i < val_count) + sprintf(&inst_name, "%s", vals[i].data.string_val); + else + sprintf(&inst_name, "%ld", i); + + sprintf(&xpath, XPATH_STATUS"/kresd-instances[name='%s']", inst_name); + kresd_get_status(inst_name, &status); + lyd_new_path(*parent, NULL, xpath, inst_name, 0, 0); + + sprintf(&xpath, XPATH_STATUS"/kresd-instances[name='%s']/status", inst_name); + lyd_new_path(*parent, NULL, xpath, status, 0, 0); + + free(status); + } + sr_free_values(vals, val_count); +} + +int resolver_start() +{ + int ret = SR_ERR_OK; + struct server_config cfg = the_worker->engine->watcher.config; + + kresd_instances_start(UNIT_START); + + if (cfg.auto_cache_gc) + cache_gc_ctl(UNIT_START); + + return ret; +} + +int set_tst_secret(const char *new_secret) +{ + int ret = 0; + sr_conn_ctx_t *connection = NULL; + sr_session_ctx_t *session = NULL; + + if (!ret) ret = sr_connect(0, &connection); + if (!ret) ret = sr_session_start(connection, SR_DS_RUNNING, &session); + if (!ret) ret = sr_set_item_str(session, XPATH_TST_SECRET, new_secret, NULL, 0); + if (!ret) ret = sr_validate(session, YM_COMMON, 0); + if (!ret) ret = sr_apply_changes(session, 0, 0); + if (ret) + kr_log_error( + "[sysrepo] failed to set '%s', %s\n", + XPATH_TST_SECRET, sr_strerror(ret)); + + sr_disconnect(connection); + + return ret; +} + +/* STATE DATA CALLBACKS */ + +static int server_status_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath, +const char *request_xpath, uint32_t request_id, struct lyd_node **parent, void *private_data) +{ + char *cache_gc_status; + + /* kresd instances status */ + kresd_instances_status(parent); + + /* Cache Garbage Collector status */ + cache_gc_get_status(&cache_gc_status); + lyd_new_path(*parent, NULL, XPATH_STATUS"/cache-gc", cache_gc_status, 0, 0); + + free(cache_gc_status); + return SR_ERR_OK; +} + +static int server_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath, +const char *request_xpath, uint32_t request_id, struct lyd_node **parent, void *private_data) +{ + assert(parent!=NULL); + + char str[128]; + struct server_config cfg = the_worker->engine->watcher.config; + lyd_new_path(*parent, NULL, XPATH_VERSION, PACKAGE_VERSION, 0, 0); + + sprintf(str, "%s", cfg.auto_start ? "true" : "false"); + lyd_new_path(*parent, NULL, XPATH_SERVER"/"YM_KRES":auto-start", str, 0, 0); + + sprintf(str, "%s", cfg.auto_cache_gc ? "true" : "false"); + lyd_new_path(*parent, NULL, XPATH_SERVER"/"YM_KRES":auto-cache-gc", str, 0, 0); + + sprintf(str, "%d", cfg.kresd_instances); + lyd_new_path(*parent, NULL, XPATH_SERVER"/"YM_KRES":kresd-instances", str, 0, 0); + + return SR_ERR_OK; +} + +static int instance_status_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath, +const char *request_xpath, uint32_t request_id, struct lyd_node **parent, void *private_data) +{ + return SR_ERR_OK; +} + +/* CONFIG DATA CALLBACKS */ + +static int server_change_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath, +sr_event_t event, uint32_t request_id, void *private_data) +{ + if(event == SR_EV_CHANGE) + { + /* validation actions*/ + } + else if (event == SR_EV_DONE) + { + int err = SR_ERR_OK; + sr_change_oper_t oper; + sr_val_t *old_value = NULL; + sr_val_t *new_value = NULL; + sr_change_iter_t *it = NULL; + struct server_config *cfg = &the_worker->engine->watcher.config; + + err = sr_get_changes_iter(session, XPATH_SERVER "/*/.", &it); + if (err != SR_ERR_OK) goto cleanup; + + while ((sr_get_change_next(session, it, &oper, &old_value, &new_value)) == SR_ERR_OK) { + + const char *leaf = remove_substr(new_value->xpath, XPATH_SERVER"/cznic-resolver-knot:"); + printf("%s\n", new_value->xpath); + if (!strcmp(leaf, "auto-start")) + cfg->auto_start = new_value->data.bool_val; + else if (!strcmp(leaf, "auto-cache-gc")) + cfg->auto_cache_gc = new_value->data.bool_val; + else if (!strcmp(leaf, "kresd-instances")) + cfg->kresd_instances = new_value->data.uint8_val; + + sr_free_val(old_value); + sr_free_val(new_value); + } + + cleanup: + sr_free_change_iter(it); + } + else if(event == SR_EV_ABORT) + { + /* abortion actions */ + } + + return SR_ERR_OK; +} + +static int tls_sticket_secret_change_cb(sr_session_ctx_t *session, +const char *module_name, const char *xpath, sr_event_t event, +uint32_t request_id, void *private_data) +{ + if(event == SR_EV_CHANGE) + { + /* validation actions*/ + } + else if (event == SR_EV_DONE) + { + int ret = 0; + uv_loop_t *loop = the_worker->loop; + ret = tst_secret_timer_init(loop); + if (ret){ + kr_log_error("[sysrepo] failed to init tls session ticket secret"); + } + } + else if(event == SR_EV_ABORT) + { + /* abortion actions */ + } + + return SR_ERR_OK; +} + +static int cache_prefill_change_cb(sr_session_ctx_t *session, const char *module_name, const char *xpath, +sr_event_t event, uint32_t request_id, void *private_data) +{ + if(event == SR_EV_CHANGE) + { + /* validation actions*/ + } + else if (event == SR_EV_DONE) + { + int err = SR_ERR_OK; + sr_change_oper_t oper; + sr_val_t *old_value = NULL; + sr_val_t *new_value = NULL; + sr_change_iter_t *it = NULL; + + struct server_config *cfg = &the_worker->engine->watcher.config; + + err = sr_get_changes_iter(session, XPATH_CACHE_PREFILL "/*", &it); + if (err != SR_ERR_OK) goto cleanup; + + lua_State *L = the_worker->engine->L; + engine_cmd(L, "modules.load('prefill')",false); + lua_getglobal(L, "prefill.config"); + lua_newtable(L); + + while ((sr_get_change_next(session, it, &oper, &old_value, &new_value)) == SR_ERR_OK) { + + printf("%s\n", new_value->xpath); + char *leaf = strrchr(new_value->xpath, '/'); + + if (leaf && !strcmp(leaf, "/origin")) { + + } + if (leaf && !strcmp(leaf, "/url")) { + lua_pushstring(L, new_value->data.string_val); + lua_setfield(L, -2, "url"); + } + + if (leaf && !strcmp(leaf, "/ca-file")){ + lua_pushstring(L, new_value->data.string_val); + lua_setfield(L, -2, "ca_file"); + } + if (leaf && !strcmp(leaf, "/refresh-interval")) { + lua_pushnumber(L, new_value->data.uint32_val); + lua_setfield(L, -2, "interval"); + } + + sr_free_val(old_value); + sr_free_val(new_value); + } + + lua_setglobal(L, "."); + engine_pcall(L, 1); + + cleanup: + sr_free_change_iter(it); + } + else if(event == SR_EV_ABORT) + { + /* abortion actions */ + } + return SR_ERR_OK; +} + +/* RPC CALLBACKS */ + +/* Callback for kresd instances controll */ +static int rpc_instance_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + const char *leaf = remove_substr(path, XPATH_SERVER"/"); + + + return 0; +} + +static int rpc_resolver_start_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + int ret = SR_ERR_OK; + struct server_config cfg = the_worker->engine->watcher.config; + + kresd_instances_start(UNIT_START); + + if (cfg.auto_cache_gc) + cache_gc_ctl(UNIT_START); + + return ret; +} + +static int rpc_resolver_stop_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + int ret = SR_ERR_OK; + struct server_config cfg = the_worker->engine->watcher.config; + + kresd_instances_start(UNIT_STOP); + + cache_gc_ctl(UNIT_STOP); + + return ret; +} + +static int rpc_resolver_restart_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + int ret = SR_ERR_OK; + struct server_config cfg = the_worker->engine->watcher.config; + + kresd_instances_start(UNIT_RESTART); + + if (cfg.auto_cache_gc) + cache_gc_ctl(UNIT_RESTART); + + return ret; +} + +static int rpc_cache_gc_start_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + int ret = cache_gc_ctl(UNIT_START); + + return 0; +} + +static int rpc_cache_gc_stop_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + int ret = cache_gc_ctl(UNIT_STOP); + + return 0; +} + +static int rpc_cache_gc_restart_cb(sr_session_ctx_t *session, const char *path, +const sr_val_t *input, const size_t input_cnt, sr_event_t event, +uint32_t request_id, sr_val_t **output, size_t *output_cnt, void *private_data) +{ + int ret = cache_gc_ctl(UNIT_RESTART); + + return 0; +} + +int sysrepo_subscr_register(sr_session_ctx_t *session, sr_subscription_ctx_t **subscription) +{ + int err = SR_ERR_OK; + + /* CONFIG CHANGES */ + + err = sr_module_change_subscribe(session, YM_COMMON, XPATH_SERVER, + server_change_cb, NULL, 0, SR_SUBSCR_NO_THREAD|SR_SUBSCR_ENABLED|SR_SUBSCR_DONE_ONLY, subscription); + if (err != SR_ERR_OK) return err; + + err = sr_module_change_subscribe(session, YM_COMMON, XPATH_TST_SECRET, + tls_sticket_secret_change_cb, NULL, 0, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE|SR_SUBSCR_DONE_ONLY, subscription); + if (err != SR_ERR_OK) return err; + + err = sr_module_change_subscribe(session, YM_COMMON, XPATH_CACHE_PREFILL, + cache_prefill_change_cb, NULL, 0, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE|SR_SUBSCR_ENABLED|SR_SUBSCR_DONE_ONLY, subscription); + if (err != SR_ERR_OK) return err; + + /* RPC OPERATIONS */ + + err = sr_rpc_subscribe(session, XPATH_RPC_BASE":start", rpc_resolver_start_cb, NULL, 0, + SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != 0) return err; + + err = sr_rpc_subscribe(session, XPATH_RPC_BASE":stop", rpc_resolver_stop_cb, NULL, 0, + SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != 0) return err; + + err = sr_rpc_subscribe(session, XPATH_RPC_BASE":restart", rpc_resolver_restart_cb, NULL, 0, + SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != 0) return err; + + err = sr_rpc_subscribe(session, XPATH_GC"/start", rpc_cache_gc_start_cb, NULL, 0, + SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != 0) return err; + + err = sr_rpc_subscribe(session, XPATH_GC"/stop", rpc_cache_gc_stop_cb, NULL, 0, + SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != 0) return err; + + err = sr_rpc_subscribe(session, XPATH_GC"/restart", rpc_cache_gc_restart_cb, NULL, 0, + SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != 0) return err; + + /* STATE DATA */ + + err = sr_oper_get_items_subscribe(session, YM_COMMON, XPATH_SERVER, server_cb, NULL, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != SR_ERR_OK) return err; + + err = sr_oper_get_items_subscribe(session, YM_COMMON, XPATH_STATUS, server_status_cb, NULL, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != SR_ERR_OK) return err; + + err = sr_oper_get_items_subscribe(session, YM_COMMON, XPATH_INSTANCES, instance_status_cb, NULL, SR_SUBSCR_NO_THREAD|SR_SUBSCR_CTX_REUSE, subscription); + if (err != SR_ERR_OK) return err; + + return err; +} \ No newline at end of file diff --git a/utils/watcher/sr_subscriptions.h b/utils/watcher/sr_subscriptions.h new file mode 100644 index 000000000..aeaa98c08 --- /dev/null +++ b/utils/watcher/sr_subscriptions.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +int sysrepo_subscr_register(sr_session_ctx_t *session, sr_subscription_ctx_t **subscription); + +int set_tst_secret(const char *secret); + +int resolver_start(); diff --git a/utils/watcher/tls.c b/utils/watcher/tls.c new file mode 100644 index 000000000..96396bd28 --- /dev/null +++ b/utils/watcher/tls.c @@ -0,0 +1,1197 @@ +/* + * Copyright (C) 2016 American Civil Liberties Union (ACLU) + * 2016-2018 CZ.NIC, z.s.p.o + * + * Initial Author: Daniel Kahn Gillmor + * Ondřej Surý + * + * This program is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program. If not, see . + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "contrib/ucw/lib.h" +#include "contrib/base64.h" +#include "io.h" +#include "tls.h" +#include "worker.h" +#include "session.h" + +#define EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE (60*60*24*7) +#define GNUTLS_PIN_MIN_VERSION 0x030400 + +/** @internal Debugging facility. */ +#ifdef DEBUG +#define DEBUG_MSG(...) kr_log_verbose("[tls] " __VA_ARGS__) +#else +#define DEBUG_MSG(...) +#endif + +struct async_write_ctx { + uv_write_t write_req; + struct tls_common_ctx *t; + char buf[]; +}; + +static char const server_logstring[] = "tls"; +static char const client_logstring[] = "tls_client"; + +static int client_verify_certificate(gnutls_session_t tls_session); + +/** + * Set mandatory security settings from + * https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles-11#section-9 + * Performance optimizations are not implemented at the moment. + */ +static int kres_gnutls_set_priority(gnutls_session_t session) { + static const char * const priorities = + "NORMAL:" /* GnuTLS defaults */ + "-VERS-TLS1.0:-VERS-TLS1.1:" /* TLS 1.2 and higher */ + /* Some distros by default allow features that are considered + * too insecure nowadays, so let's disable them explicitly. */ + "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL"; + const char *errpos = NULL; + int err = gnutls_priority_set_direct(session, priorities, &errpos); + if (err != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n", + priorities, errpos - priorities, errpos, gnutls_strerror_name(err), err); + } + return err; +} + +static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len) +{ + struct tls_common_ctx *t = (struct tls_common_ctx *)h; + assert(t != NULL); + + ssize_t avail = t->nread - t->consumed; + DEBUG_MSG("[%s] pull wanted: %zu available: %zu\n", + t->client_side ? "tls_client" : "tls", len, avail); + if (t->nread <= t->consumed) { + errno = EAGAIN; + return -1; + } + + ssize_t transfer = MIN(avail, len); + memcpy(buf, t->buf + t->consumed, transfer); + t->consumed += transfer; + return transfer; +} + +static void on_write_complete(uv_write_t *req, int status) +{ + assert(req->data != NULL); + struct async_write_ctx *async_ctx = (struct async_write_ctx *)req->data; + struct tls_common_ctx *t = async_ctx->t; + assert(t->write_queue_size); + t->write_queue_size -= 1; + free(req->data); +} + +static bool stream_queue_is_empty(struct tls_common_ctx *t) +{ + return (t->write_queue_size == 0); +} + +static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt) +{ + struct tls_common_ctx *t = (struct tls_common_ctx *)h; + + if (t == NULL) { + errno = EFAULT; + return -1; + } + + if (iovcnt == 0) { + return 0; + } + + assert(t->session); + uv_stream_t *handle = (uv_stream_t *)session_get_handle(t->session); + assert(handle && handle->type == UV_TCP); + + /* + * This is a little bit complicated. There are two different writes: + * 1. Immediate, these don't need to own the buffered data and return immediately + * 2. Asynchronous, these need to own the buffers until the write completes + * In order to avoid copying the buffer, an immediate write is tried first if possible. + * If it isn't possible to write the data without queueing, an asynchronous write + * is created (with copied buffered data). + */ + + size_t total_len = 0; + uv_buf_t uv_buf[iovcnt]; + for (int i = 0; i < iovcnt; ++i) { + uv_buf[i].base = iov[i].iov_base; + uv_buf[i].len = iov[i].iov_len; + total_len += iov[i].iov_len; + } + + /* Try to perform the immediate write first to avoid copy */ + int ret = 0; + if (stream_queue_is_empty(t)) { + ret = uv_try_write(handle, uv_buf, iovcnt); + DEBUG_MSG("[%s] push %zu <%p> = %d\n", + t->client_side ? "tls_client" : "tls", total_len, h, ret); + /* from libuv documentation - + uv_try_write will return either: + > 0: number of bytes written (can be less than the supplied buffer size). + < 0: negative error code (UV_EAGAIN is returned if no data can be sent immediately). + */ + if (ret == total_len) { + /* All the data were buffered by libuv. + * Return. */ + return ret; + } + + if (ret < 0 && ret != UV_EAGAIN) { + /* uv_try_write() has returned error code other then UV_EAGAIN. + * Return. */ + ret = -1; + errno = EIO; + return ret; + } + /* Since we are here expression below is true + * (ret != total_len) && (ret >= 0 || ret == UV_EAGAIN) + * or the same + * (ret != total_len && ret >= 0) || (ret != total_len && ret == UV_EAGAIN) + * i.e. either occurs partial write or UV_EAGAIN. + * Proceed and copy data amount to owned memory and perform async write. + */ + if (ret == UV_EAGAIN) { + /* No data were buffered, so we must buffer all the data. */ + ret = 0; + } + } + + /* Fallback when the queue is full, and it's not possible to do an immediate write */ + char *p = malloc(sizeof(struct async_write_ctx) + total_len - ret); + if (p != NULL) { + struct async_write_ctx *async_ctx = (struct async_write_ctx *)p; + /* Save pointer to session tls context */ + async_ctx->t = t; + char *buf = async_ctx->buf; + /* Skip data written in the partial write */ + size_t to_skip = ret; + /* Copy the buffer into owned memory */ + size_t off = 0; + for (int i = 0; i < iovcnt; ++i) { + if (to_skip > 0) { + /* Ignore current buffer if it's all skipped */ + if (to_skip >= uv_buf[i].len) { + to_skip -= uv_buf[i].len; + continue; + } + /* Skip only part of the buffer */ + uv_buf[i].base += to_skip; + uv_buf[i].len -= to_skip; + to_skip = 0; + } + memcpy(buf + off, uv_buf[i].base, uv_buf[i].len); + off += uv_buf[i].len; + } + uv_buf[0].base = buf; + uv_buf[0].len = off; + + /* Create an asynchronous write request */ + uv_write_t *write_req = &async_ctx->write_req; + memset(write_req, 0, sizeof(uv_write_t)); + write_req->data = p; + + /* Perform an asynchronous write with a callback */ + if (uv_write(write_req, handle, uv_buf, 1, on_write_complete) == 0) { + ret = total_len; + t->write_queue_size += 1; + } else { + free(p); + errno = EIO; + ret = -1; + } + } else { + errno = ENOMEM; + ret = -1; + } + + DEBUG_MSG("[%s] queued %zu <%p> = %d\n", + t->client_side ? "tls_client" : "tls", total_len, h, ret); + + return ret; +} + +/** Perform TLS handshake and handle error codes according to the documentation. + * See See https://gnutls.org/manual/html_node/TLS-handshake.html#TLS-handshake + * The function returns kr_ok() or success or non fatal error, kr_error(EAGAIN) on blocking, or kr_error(EIO) on fatal error. + */ +static int tls_handshake(struct tls_common_ctx *ctx, tls_handshake_cb handshake_cb) { + struct session *session = ctx->session; + const char *logstring = ctx->client_side ? client_logstring : server_logstring; + + int err = gnutls_handshake(ctx->tls_session); + if (err == GNUTLS_E_SUCCESS) { + /* Handshake finished, return success */ + ctx->handshake_state = TLS_HS_DONE; + struct sockaddr *peer = session_get_peer(session); + kr_log_verbose("[%s] TLS handshake with %s has completed\n", + logstring, kr_straddr(peer)); + if (handshake_cb) { + if (handshake_cb(session, 0) != kr_ok()) { + return kr_error(EIO); + } + } + } else if (err == GNUTLS_E_AGAIN) { + return kr_error(EAGAIN); + } else if (gnutls_error_is_fatal(err)) { + /* Fatal errors, return error as it's not recoverable */ + kr_log_verbose("[%s] gnutls_handshake failed: %s (%d)\n", + logstring, + gnutls_strerror_name(err), err); + if (handshake_cb) { + handshake_cb(session, -1); + } + return kr_error(EIO); + } else if (err == GNUTLS_E_WARNING_ALERT_RECEIVED) { + /* Handle warning when in verbose mode */ + const char *alert_name = gnutls_alert_get_name(gnutls_alert_get(ctx->tls_session)); + if (alert_name != NULL) { + struct sockaddr *peer = session_get_peer(session); + kr_log_verbose("[%s] TLS alert from %s received: %s\n", + logstring, kr_straddr(peer), alert_name); + } + } + return kr_ok(); +} + + +struct tls_ctx_t *tls_new(struct worker_ctx *worker) +{ + assert(worker != NULL); + assert(worker->engine != NULL); + + struct network *net = &worker->engine->net; + if (!net->tls_credentials) { + net->tls_credentials = tls_get_ephemeral_credentials(worker->engine); + if (!net->tls_credentials) { + kr_log_error("[tls] X.509 credentials are missing, and ephemeral credentials failed; no TLS\n"); + return NULL; + } + kr_log_info("[tls] Using ephemeral TLS credentials:\n"); + tls_credentials_log_pins(net->tls_credentials); + } + + time_t now = time(NULL); + if (net->tls_credentials->valid_until != GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION) { + if (net->tls_credentials->ephemeral_servicename) { + /* ephemeral cert: refresh if due to expire within a week */ + if (now >= net->tls_credentials->valid_until - EPHEMERAL_CERT_EXPIRATION_SECONDS_RENEW_BEFORE) { + struct tls_credentials *newcreds = tls_get_ephemeral_credentials(worker->engine); + if (newcreds) { + tls_credentials_release(net->tls_credentials); + net->tls_credentials = newcreds; + kr_log_info("[tls] Renewed expiring ephemeral X.509 cert\n"); + } else { + kr_log_error("[tls] Failed to renew expiring ephemeral X.509 cert, using existing one\n"); + } + } + } else { + /* non-ephemeral cert: warn once when certificate expires */ + if (now >= net->tls_credentials->valid_until) { + kr_log_error("[tls] X.509 certificate has expired!\n"); + net->tls_credentials->valid_until = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION; + } + } + } + + struct tls_ctx_t *tls = calloc(1, sizeof(struct tls_ctx_t)); + if (tls == NULL) { + kr_log_error("[tls] failed to allocate TLS context\n"); + return NULL; + } + + int err = gnutls_init(&tls->c.tls_session, GNUTLS_SERVER | GNUTLS_NONBLOCK); + if (err != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] gnutls_init(): %s (%d)\n", gnutls_strerror_name(err), err); + tls_free(tls); + return NULL; + } + tls->credentials = tls_credentials_reserve(net->tls_credentials); + err = gnutls_credentials_set(tls->c.tls_session, GNUTLS_CRD_CERTIFICATE, + tls->credentials->credentials); + if (err != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] gnutls_credentials_set(): %s (%d)\n", gnutls_strerror_name(err), err); + tls_free(tls); + return NULL; + } + if (kres_gnutls_set_priority(tls->c.tls_session) != GNUTLS_E_SUCCESS) { + tls_free(tls); + return NULL; + } + + tls->c.worker = worker; + tls->c.client_side = false; + + gnutls_transport_set_pull_function(tls->c.tls_session, kres_gnutls_pull); + gnutls_transport_set_vec_push_function(tls->c.tls_session, kres_gnutls_vec_push); + gnutls_transport_set_ptr(tls->c.tls_session, tls); + + if (net->tls_session_ticket_ctx) { + tls_session_ticket_enable(net->tls_session_ticket_ctx, + tls->c.tls_session); + } + + return tls; +} + +void tls_close(struct tls_common_ctx *ctx) +{ + if (ctx == NULL || ctx->tls_session == NULL) { + return; + } + + assert(ctx->session); + + if (ctx->handshake_state == TLS_HS_DONE) { + const struct sockaddr *peer = session_get_peer(ctx->session); + kr_log_verbose("[%s] closing tls connection to `%s`\n", + ctx->client_side ? "tls_client" : "tls", + kr_straddr(peer)); + ctx->handshake_state = TLS_HS_CLOSING; + gnutls_bye(ctx->tls_session, GNUTLS_SHUT_RDWR); + } +} + +void tls_free(struct tls_ctx_t *tls) +{ + if (!tls) { + return; + } + + if (tls->c.tls_session) { + /* Don't terminate TLS connection, just tear it down */ + gnutls_deinit(tls->c.tls_session); + tls->c.tls_session = NULL; + } + + tls_credentials_release(tls->credentials); + free(tls); +} + +int tls_write(uv_write_t *req, uv_handle_t *handle, knot_pkt_t *pkt, uv_write_cb cb) +{ + if (!pkt || !handle || !handle->data) { + return kr_error(EINVAL); + } + + struct session *s = handle->data; + struct tls_common_ctx *tls_ctx = session_tls_get_common_ctx(s); + + assert (tls_ctx); + assert (session_flags(s)->outgoing == tls_ctx->client_side); + + const uint16_t pkt_size = htons(pkt->size); + const char *logstring = tls_ctx->client_side ? client_logstring : server_logstring; + gnutls_session_t tls_session = tls_ctx->tls_session; + + gnutls_record_cork(tls_session); + ssize_t count = 0; + if ((count = gnutls_record_send(tls_session, &pkt_size, sizeof(pkt_size)) < 0) || + (count = gnutls_record_send(tls_session, pkt->wire, pkt->size) < 0)) { + kr_log_error("[%s] gnutls_record_send failed: %s (%zd)\n", + logstring, gnutls_strerror_name(count), count); + return kr_error(EIO); + } + + const ssize_t submitted = sizeof(pkt_size) + pkt->size; + + int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT); + if (ret < 0) { + if (!gnutls_error_is_fatal(ret)) { + return kr_error(EAGAIN); + } else { + kr_log_error("[%s] gnutls_record_uncork failed: %s (%d)\n", + logstring, gnutls_strerror_name(ret), ret); + return kr_error(EIO); + } + } + + if (ret != submitted) { + kr_log_error("[%s] gnutls_record_uncork didn't send all data (%d of %zd)\n", + logstring, ret, submitted); + return kr_error(EIO); + } + + /* The data is now accepted in gnutls internal buffers, the message can be treated as sent */ + req->handle = (uv_stream_t *)handle; + cb(req, 0); + + return kr_ok(); +} + +ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nread) +{ + struct tls_common_ctx *tls_p = session_tls_get_common_ctx(s); + if (!tls_p) { + return kr_error(ENOSYS); + } + + assert(tls_p->session == s); + const bool ok = tls_p->recv_buf == buf && nread <= sizeof(tls_p->recv_buf); + if (!ok) { + assert(false); + /* don't risk overflowing the buffer if we have a mistake somewhere */ + return kr_error(EINVAL); + } + + const char *logstring = tls_p->client_side ? client_logstring : server_logstring; + + tls_p->buf = buf; + tls_p->nread = nread >= 0 ? nread : 0; + tls_p->consumed = 0; + + /* Ensure TLS handshake is performed before receiving data. + * See https://www.gnutls.org/manual/html_node/TLS-handshake.html */ + while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) { + int err = tls_handshake(tls_p, tls_p->handshake_cb); + if (err == kr_error(EAGAIN)) { + return 0; /* Wait for more data */ + } else if (err != kr_ok()) { + return err; + } + } + + /* See https://gnutls.org/manual/html_node/Data-transfer-and-termination.html#Data-transfer-and-termination */ + ssize_t submitted = 0; + uint8_t *wire_buf = session_wirebuf_get_free_start(s); + size_t wire_buf_size = session_wirebuf_get_free_size(s); + while (true) { + ssize_t count = gnutls_record_recv(tls_p->tls_session, wire_buf, wire_buf_size); + if (count == GNUTLS_E_AGAIN) { + if (tls_p->consumed == tls_p->nread) { + /* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */ + break; /* No more data available in this libuv buffer */ + } + continue; + } else if (count == GNUTLS_E_INTERRUPTED) { + continue; + } else if (count == GNUTLS_E_REHANDSHAKE) { + /* See https://www.gnutls.org/manual/html_node/Re_002dauthentication.html */ + struct sockaddr *peer = session_get_peer(s); + kr_log_verbose("[%s] TLS rehandshake with %s has started\n", + logstring, kr_straddr(peer)); + tls_set_hs_state(tls_p, TLS_HS_IN_PROGRESS); + int err = kr_ok(); + while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) { + err = tls_handshake(tls_p, tls_p->handshake_cb); + if (err == kr_error(EAGAIN)) { + break; + } else if (err != kr_ok()) { + return err; + } + } + if (err == kr_error(EAGAIN)) { + /* pull function is out of data */ + break; + } + /* There are can be data available, check it. */ + continue; + } else if (count < 0) { + kr_log_verbose("[%s] gnutls_record_recv failed: %s (%zd)\n", + logstring, gnutls_strerror_name(count), count); + return kr_error(EIO); + } else if (count == 0) { + break; + } + DEBUG_MSG("[%s] received %zd data\n", logstring, count); + wire_buf += count; + wire_buf_size -= count; + submitted += count; + if (wire_buf_size == 0 && tls_p->consumed != tls_p->nread) { + /* session buffer is full + * whereas not all the data were consumed */ + return kr_error(ENOSPC); + } + } + /* Here all data must be consumed. */ + if (tls_p->consumed != tls_p->nread) { + /* Something went wrong, better return error. + * This is most probably due to gnutls_record_recv() did not + * consume all available network data by calling kres_gnutls_pull(). + * TODO assess the need for buffering of data amount. + */ + return kr_error(ENOSPC); + } + return submitted; +} + +#if TLS_CAN_USE_PINS +/* + DNS-over-TLS Out of band key-pinned authentication profile uses the + same form of pins as HPKP: + + e.g. pin-sha256="FHkyLhvI0n70E47cJlRTamTrnYVcsYdjUGbr79CfAVI=" + + DNS-over-TLS OOB key-pins: https://tools.ietf.org/html/rfc7858#appendix-A + HPKP pin reference: https://tools.ietf.org/html/rfc7469#appendix-A +*/ +#define PINLEN ((((32) * 8 + 4)/6) + 3 + 1) + +/* Compute pin_sha256 for the certificate. + * It may be in raw format - just TLS_SHA256_RAW_LEN bytes without termination, + * or it may be a base64 0-terminated string requiring up to + * TLS_SHA256_BASE64_BUFLEN bytes. + * \return error code */ +static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len, bool raw) +{ + if (raw && outchar_len < TLS_SHA256_RAW_LEN) { + assert(false); + return kr_error(ENOSPC); + /* With !raw we have check inside base64_encode. */ + } + gnutls_pubkey_t key; + int err = gnutls_pubkey_init(&key); + if (err != GNUTLS_E_SUCCESS) return err; + + gnutls_datum_t datum = { .data = NULL, .size = 0 }; + err = gnutls_pubkey_import_x509(key, crt, 0); + if (err != GNUTLS_E_SUCCESS) goto leave; + + err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum); + if (err != GNUTLS_E_SUCCESS) goto leave; + + char raw_pin[TLS_SHA256_RAW_LEN]; /* TMP buffer if raw == false */ + err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size, + (raw ? outchar : raw_pin)); + if (err != GNUTLS_E_SUCCESS || raw/*success*/) + goto leave; + /* Convert to non-raw. */ + err = base64_encode((uint8_t *)raw_pin, sizeof(raw_pin), + (uint8_t *)outchar, outchar_len); + if (err >= 0 && err < outchar_len) { + err = GNUTLS_E_SUCCESS; + outchar[err] = '\0'; /* base64_encode() doesn't do it */ + } else if (err >= 0) { + assert(false); + err = kr_error(ENOSPC); /* base64 fits but '\0' doesn't */ + outchar[outchar_len - 1] = '\0'; + } +leave: + gnutls_free(datum.data); + gnutls_pubkey_deinit(key); + return err; +} + +void tls_credentials_log_pins(struct tls_credentials *tls_credentials) +{ + for (int index = 0;; index++) { + gnutls_x509_crt_t *certs = NULL; + unsigned int cert_count = 0; + int err = gnutls_certificate_get_x509_crt(tls_credentials->credentials, + index, &certs, &cert_count); + if (err != GNUTLS_E_SUCCESS) { + if (err != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) { + kr_log_error("[tls] could not get X.509 certificates (%d) %s\n", + err, gnutls_strerror_name(err)); + } + return; + } + + for (int i = 0; i < cert_count; i++) { + char pin[TLS_SHA256_BASE64_BUFLEN] = { 0 }; + err = get_oob_key_pin(certs[i], pin, sizeof(pin), false); + if (err != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n", + i, err, gnutls_strerror_name(err)); + } else { + kr_log_info("[tls] RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n", + i, pin); + } + gnutls_x509_crt_deinit(certs[i]); + } + gnutls_free(certs); + } +} +#else +void tls_credentials_log_pins(struct tls_credentials *tls_credentials) +{ + kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin; GnuTLS 3.4.0+ required\n"); +} +#endif + +static int str_replace(char **where_ptr, const char *with) +{ + char *copy = with ? strdup(with) : NULL; + if (with && !copy) { + return kr_error(ENOMEM); + } + + free(*where_ptr); + *where_ptr = copy; + return kr_ok(); +} + +static time_t _get_end_entity_expiration(gnutls_certificate_credentials_t creds) +{ + gnutls_datum_t data; + gnutls_x509_crt_t cert = NULL; + int err; + time_t ret = GNUTLS_X509_NO_WELL_DEFINED_EXPIRATION; + + if ((err = gnutls_certificate_get_crt_raw(creds, 0, 0, &data)) != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] failed to get cert to check expiration: (%d) %s\n", + err, gnutls_strerror_name(err)); + goto done; + } + if ((err = gnutls_x509_crt_init(&cert)) != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] failed to initialize cert: (%d) %s\n", + err, gnutls_strerror_name(err)); + goto done; + } + if ((err = gnutls_x509_crt_import(cert, &data, GNUTLS_X509_FMT_DER)) != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] failed to construct cert while checking expiration: (%d) %s\n", + err, gnutls_strerror_name(err)); + goto done; + } + + ret = gnutls_x509_crt_get_expiration_time (cert); + done: + /* do not free data; g_c_get_crt_raw() says to treat it as + * constant. */ + gnutls_x509_crt_deinit(cert); + return ret; +} + +int tls_certificate_set(struct network *net, const char *tls_cert, const char *tls_key) +{ + if (!net) { + return kr_error(EINVAL); + } + + struct tls_credentials *tls_credentials = calloc(1, sizeof(*tls_credentials)); + if (tls_credentials == NULL) { + return kr_error(ENOMEM); + } + + int err = 0; + if ((err = gnutls_certificate_allocate_credentials(&tls_credentials->credentials)) != GNUTLS_E_SUCCESS) { + kr_log_error("[tls] gnutls_certificate_allocate_credentials() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + tls_credentials_free(tls_credentials); + return kr_error(ENOMEM); + } + if ((err = gnutls_certificate_set_x509_system_trust(tls_credentials->credentials)) < 0) { + if (err != GNUTLS_E_UNIMPLEMENTED_FEATURE) { + kr_log_error("[tls] warning: gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + tls_credentials_free(tls_credentials); + return err; + } + } + + if ((str_replace(&tls_credentials->tls_cert, tls_cert) != 0) || + (str_replace(&tls_credentials->tls_key, tls_key) != 0)) { + tls_credentials_free(tls_credentials); + return kr_error(ENOMEM); + } + + if ((err = gnutls_certificate_set_x509_key_file(tls_credentials->credentials, + tls_cert, tls_key, GNUTLS_X509_FMT_PEM)) != GNUTLS_E_SUCCESS) { + tls_credentials_free(tls_credentials); + kr_log_error("[tls] gnutls_certificate_set_x509_key_file(%s,%s) failed: %d (%s)\n", + tls_cert, tls_key, err, gnutls_strerror_name(err)); + return kr_error(EINVAL); + } + /* record the expiration date: */ + tls_credentials->valid_until = _get_end_entity_expiration(tls_credentials->credentials); + + /* Exchange the x509 credentials */ + struct tls_credentials *old_credentials = net->tls_credentials; + + /* Start using the new x509_credentials */ + net->tls_credentials = tls_credentials; + tls_credentials_log_pins(net->tls_credentials); + + if (old_credentials) { + err = tls_credentials_release(old_credentials); + if (err != kr_error(EBUSY)) { + return err; + } + } + + return kr_ok(); +} + +struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials) { + if (!tls_credentials) { + return NULL; + } + tls_credentials->count++; + return tls_credentials; +} + +int tls_credentials_release(struct tls_credentials *tls_credentials) { + if (!tls_credentials) { + return kr_error(EINVAL); + } + if (--tls_credentials->count < 0) { + tls_credentials_free(tls_credentials); + } else { + return kr_error(EBUSY); + } + return kr_ok(); +} + +void tls_credentials_free(struct tls_credentials *tls_credentials) { + if (!tls_credentials) { + return; + } + + if (tls_credentials->credentials) { + gnutls_certificate_free_credentials(tls_credentials->credentials); + } + if (tls_credentials->tls_cert) { + free(tls_credentials->tls_cert); + } + if (tls_credentials->tls_key) { + free(tls_credentials->tls_key); + } + if (tls_credentials->ephemeral_servicename) { + free(tls_credentials->ephemeral_servicename); + } + free(tls_credentials); +} + +void tls_client_param_unref(tls_client_param_t *entry) +{ + if (!entry) return; + assert(entry->refs); /* Well, we'd only leak memory. */ + --(entry->refs); + if (entry->refs) return; + + DEBUG_MSG("freeing TLS parameters %p\n", (void *)entry); + + for (int i = 0; i < entry->ca_files.len; ++i) { + free_const(entry->ca_files.at[i]); + } + array_clear(entry->ca_files); + + free_const(entry->hostname); + + for (int i = 0; i < entry->pins.len; ++i) { + free_const(entry->pins.at[i]); + } + array_clear(entry->pins); + + if (entry->credentials) { + gnutls_certificate_free_credentials(entry->credentials); + } + + if (entry->session_data.data) { + gnutls_free(entry->session_data.data); + } + + free(entry); +} +static int param_free(void **param, void *null) +{ + assert(param && *param); + tls_client_param_unref(*param); + return 0; +} +void tls_client_params_free(tls_client_params_t *params) +{ + if (!params) return; + trie_apply(params, param_free, NULL); + trie_free(params); +} + +tls_client_param_t * tls_client_param_new() +{ + tls_client_param_t *e = calloc(1, sizeof(*e)); + if (!e) { + assert(!ENOMEM); + return NULL; + } + /* Note: those array_t don't need further initialization. */ + e->refs = 1; + int ret = gnutls_certificate_allocate_credentials(&e->credentials); + if (ret != GNUTLS_E_SUCCESS) { + kr_log_error("[tls_client] error: gnutls_certificate_allocate_credentials() fails (%s)\n", + gnutls_strerror_name(ret)); + free(e); + return NULL; + } + gnutls_certificate_set_verify_function(e->credentials, client_verify_certificate); + return e; +} + +/** + * Convert an IP address and port number to binary key. + * + * \precond buffer \param key must have sufficient size + * \param addr[in] + * \param len[out] output length + * \param key[out] output buffer + */ +static bool construct_key(const union inaddr *addr, uint32_t *len, char *key) +{ + switch (addr->ip.sa_family) { + case AF_INET: + memcpy(key, &addr->ip4.sin_port, sizeof(addr->ip4.sin_port)); + memcpy(key + sizeof(addr->ip4.sin_port), &addr->ip4.sin_addr, + sizeof(addr->ip4.sin_addr)); + *len = sizeof(addr->ip4.sin_port) + sizeof(addr->ip4.sin_addr); + return true; + case AF_INET6: + memcpy(key, &addr->ip6.sin6_port, sizeof(addr->ip6.sin6_port)); + memcpy(key + sizeof(addr->ip6.sin6_port), &addr->ip6.sin6_addr, + sizeof(addr->ip6.sin6_addr)); + *len = sizeof(addr->ip6.sin6_port) + sizeof(addr->ip6.sin6_addr); + return true; + default: + assert(!EINVAL); + return false; + } +} +tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params, + const struct sockaddr *addr, bool do_insert) +{ + assert(params && addr); + /* We accept NULL for empty map; ensure the map exists if needed. */ + if (!*params) { + if (!do_insert) return NULL; + *params = trie_create(NULL); + if (!*params) { + assert(!ENOMEM); + return NULL; + } + } + /* Construct the key. */ + const union inaddr *ia = (const union inaddr *)addr; + char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)]; + uint32_t len; + if (!construct_key(ia, &len, key)) + return NULL; + /* Get the entry. */ + return (tls_client_param_t **) + (do_insert ? trie_get_ins : trie_get_try)(*params, key, len); +} + +int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr) +{ + const union inaddr *ia = (const union inaddr *)addr; + char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)]; + uint32_t len; + if (!construct_key(ia, &len, key)) + return kr_error(EINVAL); + trie_val_t param_ptr; + int ret = trie_del(params, key, len, ¶m_ptr); + if (ret) + return kr_error(ret); + tls_client_param_unref(param_ptr); + return kr_ok(); +} + +/** + * Verify that at least one certificate in the certificate chain matches + * at least one certificate pin in the non-empty params->pins array. + * \returns GNUTLS_E_SUCCESS if pin matches, any other value is an error + */ +static int client_verify_pin(const unsigned int cert_list_size, + const gnutls_datum_t *cert_list, + tls_client_param_t *params) +{ + assert(params->pins.len > 0); +#if TLS_CAN_USE_PINS + for (int i = 0; i < cert_list_size; i++) { + gnutls_x509_crt_t cert; + int ret = gnutls_x509_crt_init(&cert); + if (ret != GNUTLS_E_SUCCESS) { + return ret; + } + + ret = gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER); + if (ret != GNUTLS_E_SUCCESS) { + gnutls_x509_crt_deinit(cert); + return ret; + } + + #ifdef DEBUG + if (VERBOSE_STATUS) { + char pin_base64[TLS_SHA256_BASE64_BUFLEN]; + /* DEBUG: additionally compute and print the base64 pin. + * Not very efficient, but that's OK for DEBUG. */ + ret = get_oob_key_pin(cert, pin_base64, sizeof(pin_base64), false); + if (ret == GNUTLS_E_SUCCESS) { + DEBUG_MSG("[tls_client] received pin: %s\n", pin_base64); + } else { + DEBUG_MSG("[tls_client] failed to convert received pin\n"); + /* Now we hope that `ret` below can't differ. */ + } + } + #endif + char cert_pin[TLS_SHA256_RAW_LEN]; + /* Get raw pin and compare. */ + ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), true); + gnutls_x509_crt_deinit(cert); + if (ret != GNUTLS_E_SUCCESS) { + return ret; + } + for (size_t j = 0; j < params->pins.len; ++j) { + const uint8_t *pin = params->pins.at[j]; + if (memcmp(cert_pin, pin, TLS_SHA256_RAW_LEN) != 0) + continue; /* mismatch */ + DEBUG_MSG("[tls_client] matched a configured pin no. %zd\n", j); + return GNUTLS_E_SUCCESS; + } + DEBUG_MSG("[tls_client] none of %zd configured pin(s) matched\n", + params->pins.len); + } + + kr_log_error("[tls_client] no pin matched: %zu pins * %d certificates\n", + params->pins.len, cert_list_size); + return GNUTLS_E_CERTIFICATE_ERROR; + +#else /* TLS_CAN_USE_PINS */ + kr_log_error("[tls_client] internal inconsistency: TLS_CAN_USE_PINS\n"); + assert(false); + return GNUTLS_E_CERTIFICATE_ERROR; +#endif +} + +/** + * Verify that \param tls_session contains a valid X.509 certificate chain + * with given hostname. + * + * \returns GNUTLS_E_SUCCESS if certificate chain is valid, any other value is an error + */ +static int client_verify_certchain(gnutls_session_t tls_session, const char *hostname) +{ + if (!hostname) { + kr_log_error("[tls_client] internal config inconsistency: no hostname set\n"); + assert(false); + return GNUTLS_E_CERTIFICATE_ERROR; + } + + unsigned int status; + int ret = gnutls_certificate_verify_peers3(tls_session, hostname, &status); + if ((ret == GNUTLS_E_SUCCESS) && (status == 0)) { + return GNUTLS_E_SUCCESS; + } + + if (ret == GNUTLS_E_SUCCESS) { + gnutls_datum_t msg; + ret = gnutls_certificate_verification_status_print( + status, gnutls_certificate_type_get(tls_session), &msg, 0); + if (ret == GNUTLS_E_SUCCESS) { + kr_log_error("[tls_client] failed to verify peer certificate: " + "%s\n", msg.data); + gnutls_free(msg.data); + } else { + kr_log_error("[tls_client] failed to verify peer certificate: " + "unable to print reason: %s (%s)\n", + gnutls_strerror(ret), gnutls_strerror_name(ret)); + } /* gnutls_certificate_verification_status_print end */ + } else { + kr_log_error("[tls_client] failed to verify peer certificate: " + "gnutls_certificate_verify_peers3 error: %s (%s)\n", + gnutls_strerror(ret), gnutls_strerror_name(ret)); + } /* gnutls_certificate_verify_peers3 end */ + return GNUTLS_E_CERTIFICATE_ERROR; +} + +/** + * Verify that actual TLS security parameters of \param tls_session + * match requirements provided by user in tls_session->params. + * \returns GNUTLS_E_SUCCESS if requirements were met, any other value is an error + */ +static int client_verify_certificate(gnutls_session_t tls_session) +{ + struct tls_client_ctx_t *ctx = gnutls_session_get_ptr(tls_session); + assert(ctx->params != NULL); + + if (ctx->params->insecure) { + return GNUTLS_E_SUCCESS; + } + + gnutls_certificate_type_t cert_type = gnutls_certificate_type_get(tls_session); + if (cert_type != GNUTLS_CRT_X509) { + kr_log_error("[tls_client] invalid certificate type %i has been received\n", + cert_type); + return GNUTLS_E_CERTIFICATE_ERROR; + } + unsigned int cert_list_size = 0; + const gnutls_datum_t *cert_list = + gnutls_certificate_get_peers(tls_session, &cert_list_size); + if (cert_list == NULL || cert_list_size == 0) { + kr_log_error("[tls_client] empty certificate list\n"); + return GNUTLS_E_CERTIFICATE_ERROR; + } + + if (ctx->params->pins.len > 0) + /* check hash of the certificate but ignore everything else */ + return client_verify_pin(cert_list_size, cert_list, ctx->params); + else + return client_verify_certchain(ctx->c.tls_session, ctx->params->hostname); +} + +struct tls_client_ctx_t *tls_client_ctx_new(tls_client_param_t *entry, + struct worker_ctx *worker) +{ + struct tls_client_ctx_t *ctx = calloc(1, sizeof (struct tls_client_ctx_t)); + if (!ctx) { + return NULL; + } + unsigned int flags = GNUTLS_CLIENT | GNUTLS_NONBLOCK +#ifdef GNUTLS_ENABLE_FALSE_START + | GNUTLS_ENABLE_FALSE_START +#endif + ; + int ret = gnutls_init(&ctx->c.tls_session, flags); + if (ret != GNUTLS_E_SUCCESS) { + tls_client_ctx_free(ctx); + return NULL; + } + + ret = kres_gnutls_set_priority(ctx->c.tls_session); + if (ret != GNUTLS_E_SUCCESS) { + tls_client_ctx_free(ctx); + return NULL; + } + + /* Must take a reference on parameters as the credentials are owned by it + * and must not be freed while the session is active. */ + ++(entry->refs); + ctx->params = entry; + + ret = gnutls_credentials_set(ctx->c.tls_session, GNUTLS_CRD_CERTIFICATE, + entry->credentials); + if (ret == GNUTLS_E_SUCCESS && entry->hostname) { + ret = gnutls_server_name_set(ctx->c.tls_session, GNUTLS_NAME_DNS, + entry->hostname, strlen(entry->hostname)); + kr_log_verbose("[tls_client] set hostname, ret = %d\n", ret); + } else if (!entry->hostname) { + kr_log_verbose("[tls_client] no hostname\n"); + } + if (ret != GNUTLS_E_SUCCESS) { + tls_client_ctx_free(ctx); + return NULL; + } + + ctx->c.worker = worker; + ctx->c.client_side = true; + + gnutls_transport_set_pull_function(ctx->c.tls_session, kres_gnutls_pull); + gnutls_transport_set_vec_push_function(ctx->c.tls_session, kres_gnutls_vec_push); + gnutls_transport_set_ptr(ctx->c.tls_session, ctx); + return ctx; +} + +void tls_client_ctx_free(struct tls_client_ctx_t *ctx) +{ + if (ctx == NULL) { + return; + } + + if (ctx->c.tls_session != NULL) { + gnutls_deinit(ctx->c.tls_session); + ctx->c.tls_session = NULL; + } + + /* Must decrease the refcount for referenced parameters */ + tls_client_param_unref(ctx->params); + + free (ctx); +} + +int tls_pull_timeout_func(gnutls_transport_ptr_t h, unsigned int ms) +{ + struct tls_common_ctx *t = (struct tls_common_ctx *)h; + assert(t != NULL); + ssize_t avail = t->nread - t->consumed; + DEBUG_MSG("[%s] timeout check: available: %zu\n", + t->client_side ? "tls_client" : "tls", avail); + if (avail <= 0) { + errno = EAGAIN; + return -1; + } + return avail; +} + +int tls_client_connect_start(struct tls_client_ctx_t *client_ctx, + struct session *session, + tls_handshake_cb handshake_cb) +{ + if (session == NULL || client_ctx == NULL) { + return kr_error(EINVAL); + } + + assert(session_flags(session)->outgoing && session_get_handle(session)->type == UV_TCP); + + struct tls_common_ctx *ctx = &client_ctx->c; + + gnutls_session_set_ptr(ctx->tls_session, client_ctx); + gnutls_handshake_set_timeout(ctx->tls_session, ctx->worker->engine->net.tcp.tls_handshake_timeout); + gnutls_transport_set_pull_timeout_function(ctx->tls_session, tls_pull_timeout_func); + session_tls_set_client_ctx(session, client_ctx); + ctx->handshake_cb = handshake_cb; + ctx->handshake_state = TLS_HS_IN_PROGRESS; + ctx->session = session; + + tls_client_param_t *tls_params = client_ctx->params; + if (tls_params->session_data.data != NULL) { + gnutls_session_set_data(ctx->tls_session, tls_params->session_data.data, + tls_params->session_data.size); + } + + /* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */ + while (ctx->handshake_state <= TLS_HS_IN_PROGRESS) { + int ret = tls_handshake(ctx, handshake_cb); + if (ret != kr_ok()) { + return ret; + } + } + return kr_ok(); +} + +tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx) +{ + return ctx->handshake_state; +} + +int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state) +{ + if (state >= TLS_HS_LAST) { + return kr_error(EINVAL); + } + ctx->handshake_state = state; + return kr_ok(); +} + +int tls_client_ctx_set_session(struct tls_client_ctx_t *ctx, struct session *session) +{ + if (!ctx) { + return kr_error(EINVAL); + } + ctx->c.session = session; + return kr_ok(); +} + +#undef DEBUG_MSG diff --git a/utils/watcher/tls.h b/utils/watcher/tls.h new file mode 100644 index 000000000..aa37df313 --- /dev/null +++ b/utils/watcher/tls.h @@ -0,0 +1,242 @@ +/* Copyright (C) 2016 American Civil Liberties Union (ACLU) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +#pragma once + +#include +#include +#include +#include "lib/defines.h" +#include "lib/generic/array.h" +#include "lib/generic/trie.h" +#include "lib/utils.h" + +#define MAX_TLS_PADDING KR_EDNS_PAYLOAD +#define TLS_MAX_UNCORK_RETRIES 100 + +/* rfc 5476, 7.3 - handshake Protocol overview + * https://tools.ietf.org/html/rfc5246#page-33 + * Message flow for a full handshake (only mandatory messages) + * ClientHello --------> + ServerHello + <-------- ServerHelloDone + ClientKeyExchange + Finished --------> + <-------- Finished + * + * See also https://blog.cloudflare.com/keyless-ssl-the-nitty-gritty-technical-details/ + * So it takes 2 RTT. + * As we use session tickets, there are additional messages, add one RTT mode. + */ + #define TLS_MAX_HANDSHAKE_TIME (KR_CONN_RTT_MAX * 3) + +/** Transport session (opaque). */ +struct session; + +struct tls_ctx_t; +struct tls_client_ctx_t; +struct tls_credentials { + int count; + char *tls_cert; + char *tls_key; + gnutls_certificate_credentials_t credentials; + time_t valid_until; + char *ephemeral_servicename; +}; + + +#define TLS_SHA256_RAW_LEN 32 /* gnutls_hash_get_len(GNUTLS_DIG_SHA256) */ +/** Required buffer length for pin_sha256, including the zero terminator. */ +#define TLS_SHA256_BASE64_BUFLEN (((TLS_SHA256_RAW_LEN * 8 + 4) / 6) + 3 + 1) + +#if GNUTLS_VERSION_NUMBER >= 0x030400 + #define TLS_CAN_USE_PINS 1 +#else + #define TLS_CAN_USE_PINS 0 +#endif + + +/** TLS authentication parameters for a single address-port pair. */ +typedef struct { + uint32_t refs; /**< Reference count; consider TLS sessions in progress. */ + bool insecure; /**< Use no authentication. */ + const char *hostname; /**< Server name for SNI and certificate check, lowercased. */ + array_t(const char *) ca_files; /**< Paths to certificate files; not really used. */ + array_t(const uint8_t *) pins; /**< Certificate pins as raw unterminated strings.*/ + gnutls_certificate_credentials_t credentials; /**< CA creds. in gnutls format. */ + gnutls_datum_t session_data; /**< Session-resumption data gets stored here. */ +} tls_client_param_t; +/** Holds configuration for TLS authentication for all potential servers. + * Special case: NULL pointer also means empty. */ +typedef trie_t tls_client_params_t; + +/** Get a pointer-to-pointer to TLS auth params. + * If it didn't exist, it returns NULL (if !do_insert) or pointer to NULL. */ +tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params, + const struct sockaddr *addr, bool do_insert); + +/** Get a pointer to TLS auth params or NULL. */ +static inline tls_client_param_t * + tls_client_param_get(tls_client_params_t *params, const struct sockaddr *addr) +{ + tls_client_param_t **pe = tls_client_param_getptr(¶ms, addr, false); + return pe ? *pe : NULL; +} + +/** Allocate and initialize the structure (with ->ref = 1). */ +tls_client_param_t * tls_client_param_new(); +/** Reference-counted free(); any inside data is freed alongside. */ +void tls_client_param_unref(tls_client_param_t *entry); + +int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr); +/** Free TLS authentication parameters. */ +void tls_client_params_free(tls_client_params_t *params); + + +struct worker_ctx; +struct qr_task; +struct network; +struct engine; + +typedef enum tls_client_hs_state { + TLS_HS_NOT_STARTED = 0, + TLS_HS_IN_PROGRESS, + TLS_HS_DONE, + TLS_HS_CLOSING, + TLS_HS_LAST +} tls_hs_state_t; + +typedef int (*tls_handshake_cb) (struct session *session, int status); + + +struct tls_common_ctx { + bool client_side; + gnutls_session_t tls_session; + tls_hs_state_t handshake_state; + struct session *session; + /* for reading from the network */ + const uint8_t *buf; + ssize_t nread; + ssize_t consumed; + uint8_t recv_buf[16384]; + tls_handshake_cb handshake_cb; + struct worker_ctx *worker; + size_t write_queue_size; +}; + +struct tls_ctx_t { + /* + * Since pointer to tls_ctx_t needs to be casted + * to tls_ctx_common in some functions, + * this field must be always at first position + */ + struct tls_common_ctx c; + struct tls_credentials *credentials; +}; + +struct tls_client_ctx_t { + /* + * Since pointer to tls_client_ctx_t needs to be casted + * to tls_ctx_common in some functions, + * this field must be always at first position + */ + struct tls_common_ctx c; + tls_client_param_t *params; /**< It's reference-counted. */ +}; + +/*! Create an empty TLS context in query context */ +struct tls_ctx_t* tls_new(struct worker_ctx *worker); + +/*! Close a TLS context (call gnutls_bye()) */ +void tls_close(struct tls_common_ctx *ctx); + +/*! Release a TLS context */ +void tls_free(struct tls_ctx_t* tls); + +/*! Push new data to TLS context for sending */ +int tls_write(uv_write_t *req, uv_handle_t* handle, knot_pkt_t * pkt, uv_write_cb cb); + +/*! Unwrap incoming data from a TLS stream and pass them to TCP session. + * @return the number of newly-completed requests (>=0) or an error code + */ +ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nread); + +/*! Set TLS certificate and key from files. */ +int tls_certificate_set(struct network *net, const char *tls_cert, const char *tls_key); + +/*! Borrow TLS credentials for context. */ +struct tls_credentials *tls_credentials_reserve(struct tls_credentials *tls_credentials); + +/*! Release TLS credentials for context (decrements refcount or frees). */ +int tls_credentials_release(struct tls_credentials *tls_credentials); + +/*! Free TLS credentials, must not be called if it holds positive refcount. */ +void tls_credentials_free(struct tls_credentials *tls_credentials); + +/*! Log DNS-over-TLS OOB key-pin form of current credentials: + * https://tools.ietf.org/html/rfc7858#appendix-A */ +void tls_credentials_log_pins(struct tls_credentials *tls_credentials); + +/*! Generate new ephemeral TLS credentials. */ +struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine); + +/*! Get TLS handshake state. */ +tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx); + +/*! Set TLS handshake state. */ +int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state); + + +/*! Allocate new client TLS context */ +struct tls_client_ctx_t *tls_client_ctx_new(tls_client_param_t *entry, + struct worker_ctx *worker); + +/*! Free client TLS context */ +void tls_client_ctx_free(struct tls_client_ctx_t *ctx); + +int tls_client_connect_start(struct tls_client_ctx_t *client_ctx, + struct session *session, + tls_handshake_cb handshake_cb); + +int tls_client_ctx_set_session(struct tls_client_ctx_t *ctx, struct session *session); + + +/* Session tickets, server side. Implementation in ./tls_session_ticket-srv.c */ + +/*! Opaque struct used by tls_session_ticket_* functions. */ +struct tls_session_ticket_ctx; + +/*! Suggested maximum reasonable secret length. */ +#define TLS_SESSION_TICKET_SECRET_MAX_LEN 1024 + +/*! Create a session ticket context and initialize it (secret gets copied inside). + * + * Passing zero-length secret implies using a random key, i.e. not synchronized + * between multiple instances. + * + * Beware that knowledge of the secret (if nonempty) breaks forward secrecy, + * so you should rotate the secret regularly and securely erase all past secrets. + * With TLS < 1.3 it's probably too risky to set nonempty secret. + */ +struct tls_session_ticket_ctx * tls_session_ticket_ctx_create( + uv_loop_t *loop, const char *secret, size_t secret_len); + +/*! Try to enable session tickets for a server session. */ +void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session); + +/*! Free all resources of the session ticket context. NULL is accepted as well. */ +void tls_session_ticket_ctx_destroy(struct tls_session_ticket_ctx *ctx); + diff --git a/utils/watcher/tls_ephemeral_credentials.c b/utils/watcher/tls_ephemeral_credentials.c new file mode 100644 index 000000000..0cc17b94c --- /dev/null +++ b/utils/watcher/tls_ephemeral_credentials.c @@ -0,0 +1,249 @@ +/* + * Copyright (C) 2016 American Civil Liberties Union (ACLU) + * Copyright (C) 2016-2017 CZ.NIC, z.s.p.o. + * + * Initial Author: Daniel Kahn Gillmor + * + * This program is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation, either version 3 of the License, or (at your option) + * any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along with + * this program. If not, see . + */ + +#include +#include +#include +#include +#include + +#include "worker.h" +#include "tls.h" + +#define EPHEMERAL_PRIVKEY_FILENAME "ephemeral_key.pem" +#define INVALID_HOSTNAME "dns-over-tls.invalid" +#define EPHEMERAL_CERT_EXPIRATION_SECONDS (60*60*24*90) + +/* This is an attempt to grab an exclusive, advisory, non-blocking + * lock based on a filename. At the moment it's POSIX-only, but it + * should be abstract enough of an interface to make an implementation + * for non-posix systems if anyone cares. */ +typedef int lock_t; +static bool _lock_is_invalid(lock_t lock) +{ + return lock == -1; +} +/* a blocking lock on a given filename */ +static lock_t _lock_filename(const char *fname) +{ + lock_t lockfd = open(fname, O_RDONLY|O_CREAT, 0400); + if (lockfd == -1) + return lockfd; + /* this should be a non-blocking lock */ + if (flock(lockfd, LOCK_EX | LOCK_NB) != 0) { + close(lockfd); + return -1; + } + return lockfd; /* for cleanup later */ +} +static void _lock_unlock(lock_t *lock, const char *fname) +{ + if (lock && !_lock_is_invalid(*lock)) { + flock(*lock, LOCK_UN); + close(*lock); + *lock = -1; + unlink(fname); /* ignore errors */ + } +} + +static gnutls_x509_privkey_t get_ephemeral_privkey () +{ + gnutls_x509_privkey_t privkey = NULL; + int err; + gnutls_datum_t data = { .data = NULL, .size = 0 }; + lock_t lock; + int datafd = -1; + + /* Take a lock to ensure that two daemons started concurrently + * with a shared cache don't both create the same privkey: */ + lock = _lock_filename(EPHEMERAL_PRIVKEY_FILENAME ".lock"); + if (_lock_is_invalid(lock)) { + kr_log_error("[tls] unable to lock lockfile " EPHEMERAL_PRIVKEY_FILENAME ".lock\n"); + goto done; + } + + if ((err = gnutls_x509_privkey_init (&privkey)) < 0) { + kr_log_error("[tls] gnutls_x509_privkey_init() failed: %d (%s)\n", + err, gnutls_strerror_name(err)); + goto done; + } + + /* read from cache file (we assume that we've chdir'ed + * already, so we're just looking for the file in the + * cachedir. */ + datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_RDONLY); + if (datafd != -1) { + struct stat stat; + ssize_t bytes_read; + if (fstat(datafd, &stat)) { + kr_log_error("[tls] unable to stat ephemeral private key " EPHEMERAL_PRIVKEY_FILENAME "\n"); + goto bad_data; + } + data.data = gnutls_malloc(stat.st_size); + if (data.data == NULL) { + kr_log_error("[tls] unable to allocate memory for reading ephemeral private key\n"); + goto bad_data; + } + data.size = stat.st_size; + bytes_read = read(datafd, data.data, stat.st_size); + if (bytes_read != stat.st_size) { + kr_log_error("[tls] unable to read ephemeral private key\n"); + goto bad_data; + } + if ((err = gnutls_x509_privkey_import (privkey, &data, GNUTLS_X509_FMT_PEM)) < 0) { + kr_log_error("[tls] gnutls_x509_privkey_import() failed: %d (%s)\n", + err, gnutls_strerror_name(err)); + /* goto bad_data; */ + bad_data: + close(datafd); + datafd = -1; + } + if (data.data != NULL) { + gnutls_free(data.data); + data.data = NULL; + } + } + if (datafd == -1) { + /* if loading failed, then generate ... */ +#if GNUTLS_VERSION_NUMBER >= 0x030500 + if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_ECDSA, GNUTLS_CURVE_TO_BITS(GNUTLS_ECC_CURVE_SECP256R1), 0)) < 0) { +#else + if ((err = gnutls_x509_privkey_generate(privkey, GNUTLS_PK_RSA, gnutls_sec_param_to_pk_bits(GNUTLS_PK_RSA, GNUTLS_SEC_PARAM_MEDIUM), 0)) < 0) { +#endif + kr_log_error("[tls] gnutls_x509_privkey_init() failed: %d (%s)\n", + err, gnutls_strerror_name(err)); + gnutls_x509_privkey_deinit(privkey); + goto done; + } + /* ... and save */ + kr_log_info("[tls] Stashing ephemeral private key in " EPHEMERAL_PRIVKEY_FILENAME "\n"); + if ((err = gnutls_x509_privkey_export2(privkey, GNUTLS_X509_FMT_PEM, &data)) < 0) { + kr_log_error("[tls] gnutls_x509_privkey_export2() failed: %d (%s), not storing\n", + err, gnutls_strerror_name(err)); + } else { + datafd = open(EPHEMERAL_PRIVKEY_FILENAME, O_WRONLY|O_CREAT, 0600); + if (datafd == -1) { + kr_log_error("[tls] failed to open " EPHEMERAL_PRIVKEY_FILENAME " to store the ephemeral key\n"); + } else { + ssize_t bytes_written; + bytes_written = write(datafd, data.data, data.size); + if (bytes_written != data.size) + kr_log_error("[tls] failed to write %d octets to " + EPHEMERAL_PRIVKEY_FILENAME + " (%zd written)\n", + data.size, bytes_written); + } + } + } + done: + _lock_unlock(&lock, EPHEMERAL_PRIVKEY_FILENAME ".lock"); + if (datafd != -1) { + close(datafd); + } + if (data.data != NULL) { + gnutls_free(data.data); + } + return privkey; +} + +static gnutls_x509_crt_t get_ephemeral_cert(gnutls_x509_privkey_t privkey, const char *servicename, time_t invalid_before, time_t valid_until) +{ + gnutls_x509_crt_t cert = NULL; + int err; + /* need a random buffer of bytes */ + uint8_t serial[16]; + gnutls_rnd(GNUTLS_RND_NONCE, serial, sizeof(serial)); + /* clear the left-most bit to avoid signedness confusion: */ + serial[0] &= 0x8f; + size_t namelen = strlen(servicename); + +#define gtx(fn, ...) \ + if ((err = fn ( __VA_ARGS__ )) != GNUTLS_E_SUCCESS) { \ + kr_log_error("[tls] " #fn "() failed: %d (%s)\n", \ + err, gnutls_strerror_name(err)); \ + goto bad; } + + gtx(gnutls_x509_crt_init, &cert); + gtx(gnutls_x509_crt_set_activation_time, cert, invalid_before); + gtx(gnutls_x509_crt_set_ca_status, cert, 0); + gtx(gnutls_x509_crt_set_expiration_time, cert, valid_until); + gtx(gnutls_x509_crt_set_key, cert, privkey); + gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_CLIENT, 0); + gtx(gnutls_x509_crt_set_key_purpose_oid, cert, GNUTLS_KP_TLS_WWW_SERVER, 0); + gtx(gnutls_x509_crt_set_key_usage, cert, GNUTLS_KEY_DIGITAL_SIGNATURE); + gtx(gnutls_x509_crt_set_serial, cert, serial, sizeof(serial)); + gtx(gnutls_x509_crt_set_subject_alt_name, cert, GNUTLS_SAN_DNSNAME, servicename, namelen, GNUTLS_FSAN_SET); + gtx(gnutls_x509_crt_set_dn_by_oid,cert, GNUTLS_OID_X520_COMMON_NAME, 0, servicename, namelen); + gtx(gnutls_x509_crt_set_version, cert, 3); + gtx(gnutls_x509_crt_sign2,cert, cert, privkey, GNUTLS_DIG_SHA256, 0); /* self-sign, since it doesn't look like we can just stub-sign */ +#undef gtx + + return cert; +bad: + gnutls_x509_crt_deinit(cert); + return NULL; +} + +struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine) +{ + struct tls_credentials *creds = NULL; + gnutls_x509_privkey_t privkey = NULL; + gnutls_x509_crt_t cert = NULL; + int err; + time_t now = time(NULL); + + creds = calloc(1, sizeof(*creds)); + if (!creds) { + kr_log_error("[tls] failed to allocate memory for ephemeral credentials\n"); + return NULL; + } + if ((err = gnutls_certificate_allocate_credentials(&(creds->credentials))) < 0) { + kr_log_error("[tls] failed to allocate memory for ephemeral credentials\n"); + goto failure; + } + + creds->valid_until = now + EPHEMERAL_CERT_EXPIRATION_SECONDS; + creds->ephemeral_servicename = strdup(engine_get_hostname(engine)); + if (creds->ephemeral_servicename == NULL) { + kr_log_error("[tls] could not get server's hostname, using '" INVALID_HOSTNAME "' instead\n"); + if ((creds->ephemeral_servicename = strdup(INVALID_HOSTNAME)) == NULL) { + kr_log_error("[tls] failed to allocate memory for ephemeral credentials\n"); + goto failure; + } + } + if ((privkey = get_ephemeral_privkey()) == NULL) { + goto failure; + } + if ((cert = get_ephemeral_cert(privkey, creds->ephemeral_servicename, now - 60*15, creds->valid_until)) == NULL) { + goto failure; + } + if ((err = gnutls_certificate_set_x509_key(creds->credentials, &cert, 1, privkey)) < 0) { + kr_log_error("[tls] failed to set up ephemeral credentials\n"); + goto failure; + } + gnutls_x509_privkey_deinit(privkey); + gnutls_x509_crt_deinit(cert); + return creds; + failure: + gnutls_x509_privkey_deinit(privkey); + gnutls_x509_crt_deinit(cert); + tls_credentials_free(creds); + return NULL; +} diff --git a/utils/watcher/tls_session_ticket-srv.c b/utils/watcher/tls_session_ticket-srv.c new file mode 100644 index 000000000..ff1471b7d --- /dev/null +++ b/utils/watcher/tls_session_ticket-srv.c @@ -0,0 +1,262 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "lib/utils.h" + +/* Style: "local/static" identifiers are usually named tst_* */ + +/** The number of seconds between synchronized rotation of TLS session ticket key. */ +#define TST_KEY_LIFETIME 4096 + +/** Value from gnutls:lib/ext/session_ticket.c + * Beware: changing this needs to change the hashing implementation. */ +#define SESSION_KEY_SIZE 64 + +/** Compile-time support for setting the secret. */ +/* This is not secure with TLS <= 1.2 but TLS 1.3 and secure configuration + * is not available in GnuTLS yet. See https://gitlab.com/gnutls/gnutls/issues/477 +#ifndef TLS_SESSION_RESUMPTION_SYNC + #define TLS_SESSION_RESUMPTION_SYNC (GNUTLS_VERSION_NUMBER >= 0x030603) +#endif +*/ + +#if GNUTLS_VERSION_NUMBER < 0x030400 + /* It's of little use anyway. We may get the secret through lua, + * which creates a copy outside of our control. */ + #define gnutls_memset memset +#endif + +#ifdef GNUTLS_DIG_SHA3_512 + #define TST_HASH GNUTLS_DIG_SHA3_512 +#else + #define TST_HASH abort() +#endif + +/** Fields are internal to tst_key_* functions. */ +typedef struct tls_session_ticket_ctx { + uv_timer_t timer; /**< timer for rotation of the key */ + unsigned char key[SESSION_KEY_SIZE]; /**< the key itself */ + bool has_secret; /**< false -> key is random for each epoch */ + uint16_t hash_len; /**< length of `hash_data` */ + char hash_data[]; /**< data to hash to obtain `key`; + * it's `time_t epoch` and then the secret string */ +} tst_ctx_t; + +/** Check invariants, based on gnutls version. */ +static bool tst_key_invariants(void) +{ + static int result = 0; /*< cache for multiple invocations */ + if (result) return result > 0; + bool ok = true; + #if TLS_SESSION_RESUMPTION_SYNC + /* SHA3-512 output size may never change, but let's check it anyway :-) */ + ok = ok && gnutls_hash_get_len(TST_HASH) == SESSION_KEY_SIZE; + #endif + /* The ticket key size might change in a different gnutls version. */ + gnutls_datum_t key = { 0, 0 }; + ok = ok && gnutls_session_ticket_key_generate(&key) == 0 + && key.size == SESSION_KEY_SIZE; + free(key.data); + result = ok ? 1 : -1; + return ok; +} + +/** Create the internal structures and copy the secret. Beware: secret must be kept secure. */ +static tst_ctx_t * tst_key_create(const char *secret, size_t secret_len, uv_loop_t *loop) +{ + const size_t hash_len = sizeof(time_t) + secret_len; + if (secret_len && + (!secret || hash_len > UINT16_MAX || hash_len < secret_len)) { + assert(!EINVAL); + return NULL; + /* reasonable secret_len is best enforced in config API */ + } + if (!tst_key_invariants()) { + assert(!EFAULT); + return NULL; + } + #if !TLS_SESSION_RESUMPTION_SYNC + if (secret_len) { + kr_log_error("[tls] session ticket: secrets were not enabled at compile-time (your GnuTLS version is not supported)\n"); + return NULL; /* ENOTSUP */ + } + #endif + + tst_ctx_t *ctx = malloc(sizeof(*ctx) + hash_len); /* can be slightly longer */ + if (!ctx) return NULL; + ctx->has_secret = secret_len > 0; + ctx->hash_len = hash_len; + if (secret_len) { + memcpy(ctx->hash_data + sizeof(time_t), secret, secret_len); + } + + if (uv_timer_init(loop, &ctx->timer) != 0) { + free(ctx); + return NULL; + } + ctx->timer.data = ctx; + return ctx; +} + +/** Random variant of secret rotation: generate into key_tmp and copy. */ +static int tst_key_get_random(tst_ctx_t *ctx) +{ + gnutls_datum_t key_tmp = { NULL, 0 }; + int err = gnutls_session_ticket_key_generate(&key_tmp); + if (err) return kr_error(err); + if (key_tmp.size != SESSION_KEY_SIZE) { + assert(!EFAULT); + return kr_error(EFAULT); + } + memcpy(ctx->key, key_tmp.data, SESSION_KEY_SIZE); + gnutls_memset(key_tmp.data, 0, SESSION_KEY_SIZE); + free(key_tmp.data); + return kr_ok(); +} + +/** Recompute the session ticket key, if epoch has changed or forced. */ +static int tst_key_update(tst_ctx_t *ctx, time_t epoch, bool force_update) +{ + if (!ctx || ctx->hash_len < sizeof(epoch)) { + assert(!EINVAL); + return kr_error(EINVAL); + } + /* documented limitation: time_t and endianess must match + * on instances sharing a secret */ + if (!force_update && memcmp(ctx->hash_data, &epoch, sizeof(epoch)) == 0) { + return kr_ok(); /* we are up to date */ + } + memcpy(ctx->hash_data, &epoch, sizeof(epoch)); + + if (!ctx->has_secret) { + return tst_key_get_random(ctx); + } + /* Otherwise, deterministic variant of secret rotation, if supported. */ + #if !TLS_SESSION_RESUMPTION_SYNC + assert(false); + return kr_error(ENOTSUP); + #else + int err = gnutls_hash_fast(TST_HASH, ctx->hash_data, + ctx->hash_len, ctx->key); + return err == 0 ? kr_ok() : kr_error(err); + #endif +} + +/** Free all resources of the key (securely). */ +static void tst_key_destroy(uv_handle_t *timer) +{ + assert(timer); + tst_ctx_t *ctx = timer->data; + assert(ctx); + gnutls_memset(ctx, 0, offsetof(tst_ctx_t, hash_data) + ctx->hash_len); + free(ctx); +} + +static void tst_key_check(uv_timer_t *timer, bool force_update); +static void tst_timer_callback(uv_timer_t *timer) +{ + tst_key_check(timer, false); +} + +/** Update the ST key if needed and reschedule itself via the timer. */ +static void tst_key_check(uv_timer_t *timer, bool force_update) +{ + tst_ctx_t *stst = (tst_ctx_t *)timer->data; + /* Compute the current epoch. */ + struct timeval now; + if (gettimeofday(&now, NULL)) { + kr_log_error("[tls] session ticket: gettimeofday failed, %s\n", + strerror(errno)); + return; + } + uv_update_time(timer->loop); /* to have sync. between real and mono time */ + const time_t epoch = now.tv_sec / TST_KEY_LIFETIME; + /* Update the key; new sessions will fetch it from the location. + * Old ones hopefully can't get broken by that; documentation + * for gnutls_session_ticket_enable_server() doesn't say. */ + int err = tst_key_update(stst, epoch, force_update); + if (err) { + assert(err != kr_error(EINVAL)); + kr_log_error("[tls] session ticket: failed rotation, err = %d\n", err); + } + /* Reschedule. */ + const time_t tv_sec_next = (epoch + 1) * TST_KEY_LIFETIME; + const uint64_t ms_until_second = 1000 - (now.tv_usec + 501) / 1000; + const uint64_t remain_ms = (tv_sec_next - now.tv_sec - 1) * (uint64_t)1000 + + ms_until_second + 1; + /* ^ +1 because we don't want to wake up half a millisecond before the epoch! */ + assert(remain_ms < (TST_KEY_LIFETIME + 1 /*rounding tolerance*/) * 1000); + kr_log_verbose("[tls] session ticket: epoch %"PRIu64 + ", scheduling rotation check in %"PRIu64" ms\n", + (uint64_t)epoch, remain_ms); + err = uv_timer_start(timer, &tst_timer_callback, remain_ms, 0); + if (err) { + assert(false); + kr_log_error("[tls] session ticket: failed to schedule, err = %d\n", err); + } +} + +/* Implementation for prototypes from ./tls.h */ + +void tls_session_ticket_enable(struct tls_session_ticket_ctx *ctx, gnutls_session_t session) +{ + assert(ctx && session); + const gnutls_datum_t gd = { + .size = SESSION_KEY_SIZE, + .data = ctx->key, + }; + int err = gnutls_session_ticket_enable_server(session, &gd); + if (err) { + kr_log_error("[tls] failed to enable session tickets: %s (%d)\n", + gnutls_strerror_name(err), err); + /* but continue without tickets */ + } +} + +tst_ctx_t * tls_session_ticket_ctx_create(uv_loop_t *loop, const char *secret, + size_t secret_len) +{ + assert(loop && (!secret_len || secret)); + #if GNUTLS_VERSION_NUMBER < 0x030500 + /* We would need different SESSION_KEY_SIZE; avoid assert. */ + return NULL; + #endif + tst_ctx_t *ctx = tst_key_create(secret, secret_len, loop); + if (ctx) { + tst_key_check(&ctx->timer, true); + } + return ctx; +} + +void tls_session_ticket_ctx_destroy(tst_ctx_t *ctx) +{ + if (ctx == NULL) { + return; + } + uv_close((uv_handle_t *)&ctx->timer, &tst_key_destroy); +} + diff --git a/utils/watcher/udp_queue.c b/utils/watcher/udp_queue.c new file mode 100644 index 000000000..833c6fcf0 --- /dev/null +++ b/utils/watcher/udp_queue.c @@ -0,0 +1,170 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "kresconfig.h" +#include "udp_queue.h" + +#include "session.h" +#include "worker.h" +#include "lib/generic/array.h" +#include "lib/utils.h" + +struct qr_task; + +#include +#include + + +#if !ENABLE_SENDMMSG +int udp_queue_init_global(uv_loop_t *loop) +{ + return 0; +} +/* Appease the linker in case this unused call isn't optimized out. */ +void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task) +{ + abort(); +} +#else + +/* LATER: it might be useful to have this configurable during runtime, + * but the structures below would have to change a little (broken up). */ +#define UDP_QUEUE_LEN 64 + +/** A queue of up to UDP_QUEUE_LEN messages, meant for the same socket. */ +typedef struct { + int len; /**< The number of messages in the queue: 0..UDP_QUEUE_LEN */ + struct mmsghdr msgvec[UDP_QUEUE_LEN]; /**< Parameter for sendmmsg() */ + struct { + struct qr_task *task; /**< Links for completion callbacks. */ + struct iovec msg_iov[1]; /**< storage for .msgvec[i].msg_iov */ + } items[UDP_QUEUE_LEN]; +} udp_queue_t; + +static udp_queue_t * udp_queue_create() +{ + udp_queue_t *q = calloc(1, sizeof(*q)); + for (int i = 0; i < UDP_QUEUE_LEN; ++i) { + struct msghdr *mhi = &q->msgvec[i].msg_hdr; + /* These shall remain always the same. */ + mhi->msg_iov = q->items[i].msg_iov; + mhi->msg_iovlen = 1; + /* msg_name and msg_namelen will be per-call, + * and the rest is OK to remain zeroed all the time. */ + } + return q; +} + +/** Global state for udp_queue_*. Note: we never free the pointed-to memory. */ +struct { + /** Singleton map: fd -> udp_queue_t, as a simple array of pointers. */ + udp_queue_t **udp_queues; + int udp_queues_len; + + /** List of FD numbers that might have a non-empty queue. */ + array_t(int) waiting_fds; + + uv_check_t check_handle; +} static state = {0}; + +/** Empty the given queue. The queue is assumed to exist (but may be empty). */ +static void udp_queue_send(int fd) +{ + udp_queue_t *const q = state.udp_queues[fd]; + if (!q->len) return; + int sent_len = sendmmsg(fd, q->msgvec, q->len, 0); + /* ATM we don't really do anything about failures. */ + int err = sent_len < 0 ? errno : EAGAIN /* unknown error, really */; + if (unlikely(sent_len != q->len)) { + if (err != EWOULDBLOCK) { + kr_log_error("ERROR: udp sendmmsg() sent %d / %d; %s\n", + sent_len, q->len, strerror(err)); + } else { + const uint64_t stamp_now = kr_now(); + static uint64_t stamp_last = 0; + if (stamp_now > stamp_last + 60*1000) { + kr_log_info("WARNING: dropped UDP reply packet(s) due to network overload (reported at most once per minute)\n"); + stamp_last = stamp_now; + } + } + } + for (int i = 0; i < q->len; ++i) { + qr_task_on_send(q->items[i].task, NULL, i < sent_len ? 0 : err); + worker_task_unref(q->items[i].task); + } + q->len = 0; +} + +/** Periodical callback to send all queued packets. */ +static void udp_queue_check(uv_check_t *handle) +{ + for (int i = 0; i < state.waiting_fds.len; ++i) { + udp_queue_send(state.waiting_fds.at[i]); + } + state.waiting_fds.len = 0; +} + +int udp_queue_init_global(uv_loop_t *loop) +{ + int ret = uv_check_init(loop, &state.check_handle); + if (!ret) ret = uv_check_start(&state.check_handle, udp_queue_check); + return ret; +} + +void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task) +{ + if (fd < 0) { + kr_log_error("ERROR: called udp_queue_push(fd = %d, ...)\n", fd); + abort(); + } + worker_task_ref(task); + /* Get a valid correct queue. */ + if (fd >= state.udp_queues_len) { + const int new_len = fd + 1; + state.udp_queues = realloc(state.udp_queues, + sizeof(state.udp_queues[0]) * new_len); + if (!state.udp_queues) abort(); + memset(state.udp_queues + state.udp_queues_len, 0, + sizeof(state.udp_queues[0]) * (new_len - state.udp_queues_len)); + state.udp_queues_len = new_len; + } + if (unlikely(state.udp_queues[fd] == NULL)) + state.udp_queues[fd] = udp_queue_create(); + udp_queue_t *const q = state.udp_queues[fd]; + + /* Append to the queue */ + struct sockaddr *sa = (struct sockaddr *)/*const-cast*/req->qsource.addr; + q->msgvec[q->len].msg_hdr.msg_name = sa; + q->msgvec[q->len].msg_hdr.msg_namelen = kr_sockaddr_len(sa); + q->items[q->len].task = task; + q->items[q->len].msg_iov[0] = (struct iovec){ + .iov_base = req->answer->wire, + .iov_len = req->answer->size, + }; + if (q->len == 0) + array_push(state.waiting_fds, fd); + ++(q->len); + + if (q->len >= UDP_QUEUE_LEN) { + assert(q->len == UDP_QUEUE_LEN); + udp_queue_send(fd); + /* We don't need to search state.waiting_fds; + * anyway, it's more efficient to let the hook do that. */ + } +} + +#endif + diff --git a/utils/watcher/udp_queue.h b/utils/watcher/udp_queue.h new file mode 100644 index 000000000..c1730c056 --- /dev/null +++ b/utils/watcher/udp_queue.h @@ -0,0 +1,28 @@ +/* Copyright (C) 2019 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include +struct kr_request; +struct qr_task; + +/** Initialize the global state for udp_queue. */ +int udp_queue_init_global(uv_loop_t *loop); + +/** Send req->answer via UDP, possibly not immediately. */ +void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task); + diff --git a/utils/watcher/watcher.c b/utils/watcher/watcher.c new file mode 100644 index 000000000..e173daca3 --- /dev/null +++ b/utils/watcher/watcher.c @@ -0,0 +1,143 @@ +#include +#include +#include + +#include +#include + +#include "kresconfig.h" +#include "lib/utils.h" +#include "modules/sysrepo/common/sysrepo.h" + +#include "worker.h" +#include "watcher.h" +#include "sr_subscriptions.h" + +/* 12 hours interval */ +#define TST_SECRET_CYCLE 12*60*60*1000 + + +/* default configuration */ +struct server_config default_config = { + .auto_start = false, + .auto_cache_gc = true, + .kresd_instances = 1 +}; + +static void tst_secret_check(uv_timer_t *timer, bool timer_update); +static void tst_timer_callback(uv_timer_t *timer) +{ + tst_secret_check(timer, true); +} + +static tst_secret_ctx_t * tst_secret_create(uv_loop_t *loop) +{ + struct tst_secret_ctx *ctx = malloc(sizeof(*ctx)); + if (!ctx) return NULL; + + if (uv_timer_init(loop, &ctx->timer) != 0) { + free(ctx); + return NULL; + } + ctx->timer.data = ctx; + return ctx; +} + +static void tst_secret_check(uv_timer_t *timer, bool timer_update) +{ + int ret = 0; + uv_update_time(timer->loop); + + if(timer_update) { + uint8_t *base64; + gnutls_datum_t key_tmp = { NULL, 0 }; + ret = gnutls_session_ticket_key_generate(&key_tmp); + if (ret){ + kr_log_error("[watcher] failed to generate tls sticket secret, %s", strerror(ret)); + return; + } + + int32_t len = base64_encode_alloc((uint8_t *)key_tmp.data, sizeof key_tmp.data, &base64); + if (len < 0) { + kr_log_error("[watcher] failed to encode tls sticket secret in base64"); + return; + } + + base64[len-1] = '\0'; + char *secret = (char*) base64; + + kr_log_info("[watcher] generated new secret for tls session ticket\n"); + ret = set_tst_secret(secret); + + free(key_tmp.data); + free(secret); + } + uv_timer_start(timer, &tst_timer_callback, TST_SECRET_CYCLE, 0); +} + +static tst_secret_ctx_t * tst_secret_ctx_create(uv_loop_t *loop, bool timer_update) +{ + assert(loop); + tst_secret_ctx_t *ctx = tst_secret_create(loop); + if (ctx) { + tst_secret_check(&ctx->timer, timer_update); + } + kr_log_info("[watcher] new context for tls_sticket_secret created\n"); + return ctx; +} + +static void tst_secret_destroy(uv_handle_t *timer) +{ + assert(timer); + struct tst_secret_ctx *ctx = timer->data; + assert(ctx); + free(ctx); +} + + +static void tst_secret_timer_destroy(tst_secret_ctx_t *ctx) +{ + if (ctx == NULL) { + return; + } + uv_close((uv_handle_t *)&ctx->timer, &tst_secret_destroy); +} + +int tst_secret_timer_init(uv_loop_t *loop) +{ + tst_secret_ctx_t *tst_ctx = the_worker->engine->watcher.tst_secret; + + tst_secret_timer_destroy(tst_ctx); + + tst_ctx = tst_secret_ctx_create(loop, false); + the_worker->engine->watcher.tst_secret = tst_ctx; + if (the_worker->engine->watcher.tst_secret == NULL) { + kr_log_error("[watcher] failed to create tls session ticket secret context"); + return 1; + } + return 0; +} + +void watcher_init(struct watcher_context *watcher, uv_loop_t *loop) +{ + assert(watcher != NULL); + if (watcher != NULL) { + watcher->loop = loop; + watcher->config = default_config; + + /* Init sysrepo context */ + watcher->sysrepo = sysrepo_ctx_init(); + + /* Init timer for tls session ticket secret generation */ + watcher->tst_secret = tst_secret_ctx_create(loop, true); + } +} + +void watcher_deinit(struct watcher_context *watcher) +{ + assert(watcher); + if (watcher != NULL) { + sysrepo_ctx_deinit(watcher->sysrepo); + tst_secret_timer_destroy(watcher->tst_secret); + } +} \ No newline at end of file diff --git a/utils/watcher/watcher.h b/utils/watcher/watcher.h new file mode 100644 index 000000000..2122e1275 --- /dev/null +++ b/utils/watcher/watcher.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "modules/sysrepo/common/sysrepo.h" + + +typedef struct server_config { + bool auto_start; + bool auto_cache_gc; + uint8_t kresd_instances; +}; + +typedef struct tst_secret_ctx { + uv_timer_t timer; +} tst_secret_ctx_t; + +typedef struct sdbus_ctx { + +} sdbus_ctx_t; + +struct watcher_context { + uv_loop_t *loop; + sysrepo_uv_ctx_t *sysrepo; + sdbus_ctx_t *sdbus; + tst_secret_ctx_t *tst_secret; + struct server_config config; +}; + +void watcher_init(struct watcher_context *watcher, uv_loop_t *loop); + +void watcher_deinit(struct watcher_context *watcher); \ No newline at end of file diff --git a/utils/watcher/worker.c b/utils/watcher/worker.c new file mode 100644 index 000000000..542f08cf5 --- /dev/null +++ b/utils/watcher/worker.c @@ -0,0 +1,2060 @@ +/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#include "kresconfig.h" +#include "worker.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__GLIBC__) && defined(_GNU_SOURCE) +#include +#endif +#include +#include +#include +#include + +#include "bindings/api.h" +#include "engine.h" +#include "io.h" +#include "session.h" +#include "tls.h" +#include "udp_queue.h" +#include "zimport.h" +#include "lib/layer.h" +#include "lib/utils.h" + + +/* Magic defaults for the worker. */ +#ifndef MP_FREELIST_SIZE +# ifdef __clang_analyzer__ +# define MP_FREELIST_SIZE 0 +# else +# define MP_FREELIST_SIZE 64 /**< Maximum length of the worker mempool freelist */ +# endif +#endif +#ifndef QUERY_RATE_THRESHOLD +#define QUERY_RATE_THRESHOLD (2 * MP_FREELIST_SIZE) /**< Nr of parallel queries considered as high rate */ +#endif +#ifndef MAX_PIPELINED +#define MAX_PIPELINED 100 +#endif + +#define VERBOSE_MSG(qry, ...) QRVERBOSE(qry, "wrkr", __VA_ARGS__) + +/** Client request state. */ +struct request_ctx +{ + struct kr_request req; + + struct { + /** Requestor's address; separate because of UDP session "sharing". */ + union inaddr addr; + /** NULL if the request didn't come over network. */ + struct session *session; + } source; + + struct worker_ctx *worker; + struct qr_task *task; +}; + +/** Query resolution task. */ +struct qr_task +{ + struct request_ctx *ctx; + knot_pkt_t *pktbuf; + qr_tasklist_t waiting; + struct session *pending[MAX_PENDING]; + uint16_t pending_count; + uint16_t addrlist_count; + uint16_t addrlist_turn; + uint16_t timeouts; + uint16_t iter_count; + struct sockaddr *addrlist; + uint32_t refs; + bool finished : 1; + bool leading : 1; + uint64_t creation_time; +}; + + +/* Convenience macros */ +#define qr_task_ref(task) \ + do { ++(task)->refs; } while(0) +#define qr_task_unref(task) \ + do { \ + if (task) \ + assert((task)->refs > 0); \ + if ((task) && --(task)->refs == 0) \ + qr_task_free((task)); \ + } while (0) + +/** @internal get key for tcp session + * @note kr_straddr() return pointer to static string + */ +#define tcpsess_key(addr) kr_straddr(addr) + +/* Forward decls */ +static void qr_task_free(struct qr_task *task); +static int qr_task_step(struct qr_task *task, + const struct sockaddr *packet_source, + knot_pkt_t *packet); +static int qr_task_send(struct qr_task *task, struct session *session, + const struct sockaddr *addr, knot_pkt_t *pkt); +static int qr_task_finalize(struct qr_task *task, int state); +static void qr_task_complete(struct qr_task *task); +static struct session* worker_find_tcp_connected(struct worker_ctx *worker, + const struct sockaddr *addr); +static int worker_add_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr *addr, + struct session *session); +static struct session* worker_find_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr *addr); +static void on_tcp_connect_timeout(uv_timer_t *timer); + +struct worker_ctx the_worker_value; /**< Static allocation is suitable for the singleton. */ +struct worker_ctx *the_worker = NULL; + +/*! @internal Create a UDP/TCP handle for an outgoing AF_INET* connection. + * socktype is SOCK_* */ +static uv_handle_t *ioreq_spawn(struct worker_ctx *worker, + int socktype, sa_family_t family, bool has_tls) +{ + bool precond = (socktype == SOCK_DGRAM || socktype == SOCK_STREAM) + && (family == AF_INET || family == AF_INET6); + if (!precond) { + assert(false); + kr_log_verbose("[work] ioreq_spawn: pre-condition failed\n"); + return NULL; + } + + /* Create connection for iterative query */ + uv_handle_t *handle = malloc(socktype == SOCK_DGRAM + ? sizeof(uv_udp_t) : sizeof(uv_tcp_t)); + if (!handle) { + return NULL; + } + int ret = io_create(worker->loop, handle, socktype, family, has_tls); + if (ret) { + if (ret == UV_EMFILE) { + worker->too_many_open = true; + worker->rconcurrent_highwatermark = worker->stats.rconcurrent; + } + free(handle); + return NULL; + } + + /* Bind to outgoing address, according to IP v4/v6. */ + union inaddr *addr; + if (family == AF_INET) { + addr = (union inaddr *)&worker->out_addr4; + } else { + addr = (union inaddr *)&worker->out_addr6; + } + if (addr->ip.sa_family != AF_UNSPEC) { + assert(addr->ip.sa_family == family); + if (socktype == SOCK_DGRAM) { + uv_udp_t *udp = (uv_udp_t *)handle; + ret = uv_udp_bind(udp, &addr->ip, 0); + } else if (socktype == SOCK_STREAM){ + uv_tcp_t *tcp = (uv_tcp_t *)handle; + ret = uv_tcp_bind(tcp, &addr->ip, 0); + } + } + + if (ret != 0) { + io_deinit(handle); + free(handle); + return NULL; + } + + /* Set current handle as a subrequest type. */ + struct session *session = handle->data; + session_flags(session)->outgoing = true; + /* Connect or issue query datagram */ + return handle; +} + +static void ioreq_kill_pending(struct qr_task *task) +{ + for (uint16_t i = 0; i < task->pending_count; ++i) { + session_kill_ioreq(task->pending[i], task); + } + task->pending_count = 0; +} + +/** @cond This memory layout is internal to mempool.c, use only for debugging. */ +#if defined(__SANITIZE_ADDRESS__) +struct mempool_chunk { + struct mempool_chunk *next; + size_t size; +}; +static void mp_poison(struct mempool *mp, bool poison) +{ + if (!poison) { /* @note mempool is part of the first chunk, unpoison it first */ + kr_asan_unpoison(mp, sizeof(*mp)); + } + struct mempool_chunk *chunk = mp->state.last[0]; + void *chunk_off = (uint8_t *)chunk - chunk->size; + if (poison) { + kr_asan_poison(chunk_off, chunk->size); + } else { + kr_asan_unpoison(chunk_off, chunk->size); + } +} +#else +#define mp_poison(mp, enable) +#endif +/** @endcond */ + +/** Get a mempool. (Recycle if possible.) */ +static inline struct mempool *pool_borrow(struct worker_ctx *worker) +{ + struct mempool *mp = NULL; + if (worker->pool_mp.len > 0) { + mp = array_tail(worker->pool_mp); + array_pop(worker->pool_mp); + mp_poison(mp, 0); + } else { /* No mempool on the freelist, create new one */ + mp = mp_new (4 * CPU_PAGE_SIZE); + } + return mp; +} + +/** Return a mempool. (Cache them up to some count.) */ +static inline void pool_release(struct worker_ctx *worker, struct mempool *mp) +{ + if (worker->pool_mp.len < MP_FREELIST_SIZE) { + mp_flush(mp); + array_push(worker->pool_mp, mp); + mp_poison(mp, 1); + } else { + mp_delete(mp); + } +} + +/** Create a key for an outgoing subrequest: qname, qclass, qtype. + * @param key Destination buffer for key size, MUST be SUBREQ_KEY_LEN or larger. + * @return key length if successful or an error + */ +static const size_t SUBREQ_KEY_LEN = KR_RRKEY_LEN; +static int subreq_key(char *dst, knot_pkt_t *pkt) +{ + assert(pkt); + return kr_rrkey(dst, knot_pkt_qclass(pkt), knot_pkt_qname(pkt), + knot_pkt_qtype(pkt), knot_pkt_qtype(pkt)); +} + +/** Create and initialize a request_ctx (on a fresh mempool). + * + * handle and addr point to the source of the request, and they are NULL + * in case the request didn't come from network. + */ +static struct request_ctx *request_create(struct worker_ctx *worker, + struct session *session, + const struct sockaddr *peer, + uint32_t uid) +{ + knot_mm_t pool = { + .ctx = pool_borrow(worker), + .alloc = (knot_mm_alloc_t) mp_alloc + }; + + /* Create request context */ + struct request_ctx *ctx = mm_alloc(&pool, sizeof(*ctx)); + if (!ctx) { + pool_release(worker, pool.ctx); + return NULL; + } + + memset(ctx, 0, sizeof(*ctx)); + + /* TODO Relocate pool to struct request */ + ctx->worker = worker; + if (session) { + assert(session_flags(session)->outgoing == false); + } + ctx->source.session = session; + + struct kr_request *req = &ctx->req; + req->pool = pool; + req->vars_ref = LUA_NOREF; + req->uid = uid; + if (session) { + /* We assume the session will be alive during the whole life of the request. */ + req->qsource.dst_addr = session_get_sockname(session); + req->qsource.flags.tcp = session_get_handle(session)->type == UV_TCP; + req->qsource.flags.tls = session_flags(session)->has_tls; + /* We need to store a copy of peer address. */ + memcpy(&ctx->source.addr.ip, peer, kr_sockaddr_len(peer)); + req->qsource.addr = &ctx->source.addr.ip; + } + + worker->stats.rconcurrent += 1; + + return ctx; +} + +/** More initialization, related to the particular incoming query/packet. */ +static int request_start(struct request_ctx *ctx, knot_pkt_t *query) +{ + assert(query && ctx); + size_t answer_max = KNOT_WIRE_MIN_PKTSIZE; + struct kr_request *req = &ctx->req; + + /* source.session can be empty if request was generated by kresd itself */ + struct session *s = ctx->source.session; + if (!s || session_get_handle(s)->type == UV_TCP) { + answer_max = KNOT_WIRE_MAX_PKTSIZE; + } else if (knot_pkt_has_edns(query)) { /* EDNS */ + answer_max = MAX(knot_edns_get_payload(query->opt_rr), + KNOT_WIRE_MIN_PKTSIZE); + } + req->qsource.size = query->size; + if (knot_pkt_has_tsig(query)) { + req->qsource.size += query->tsig_wire.len; + } + + knot_pkt_t *answer = knot_pkt_new(NULL, answer_max, &req->pool); + if (!answer) { /* Failed to allocate answer */ + return kr_error(ENOMEM); + } + + knot_pkt_t *pkt = knot_pkt_new(NULL, req->qsource.size, &req->pool); + if (!pkt) { + return kr_error(ENOMEM); + } + + int ret = knot_pkt_copy(pkt, query); + if (ret != KNOT_EOK && ret != KNOT_ETRAIL) { + return kr_error(ENOMEM); + } + req->qsource.packet = pkt; + + /* Start resolution */ + struct worker_ctx *worker = ctx->worker; + struct engine *engine = worker->engine; + kr_resolve_begin(req, &engine->resolver, answer); + worker->stats.queries += 1; + /* Throttle outbound queries only when high pressure */ + if (worker->stats.concurrent < QUERY_RATE_THRESHOLD) { + req->options.NO_THROTTLE = true; + } + return kr_ok(); +} + +static void request_free(struct request_ctx *ctx) +{ + struct worker_ctx *worker = ctx->worker; + /* Dereference any Lua vars table if exists */ + if (ctx->req.vars_ref != LUA_NOREF) { + lua_State *L = worker->engine->L; + /* Get worker variables table */ + lua_rawgeti(L, LUA_REGISTRYINDEX, worker->vars_table_ref); + /* Get next free element (position 0) and store it under current reference (forming a list) */ + lua_rawgeti(L, -1, 0); + lua_rawseti(L, -2, ctx->req.vars_ref); + /* Set current reference as the next free element */ + lua_pushinteger(L, ctx->req.vars_ref); + lua_rawseti(L, -2, 0); + lua_pop(L, 1); + ctx->req.vars_ref = LUA_NOREF; + } + /* Return mempool to ring or free it if it's full */ + pool_release(worker, ctx->req.pool.ctx); + /* @note The 'task' is invalidated from now on. */ + worker->stats.rconcurrent -= 1; +} + +static struct qr_task *qr_task_create(struct request_ctx *ctx) +{ + /* How much can client handle? */ + struct engine *engine = ctx->worker->engine; + size_t pktbuf_max = KR_EDNS_PAYLOAD; + if (engine->resolver.opt_rr) { + pktbuf_max = MAX(knot_edns_get_payload(engine->resolver.opt_rr), + pktbuf_max); + } + + /* Create resolution task */ + struct qr_task *task = mm_alloc(&ctx->req.pool, sizeof(*task)); + if (!task) { + return NULL; + } + memset(task, 0, sizeof(*task)); /* avoid accidentally unintialized fields */ + + /* Create packet buffers for answer and subrequests */ + knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &ctx->req.pool); + if (!pktbuf) { + mm_free(&ctx->req.pool, task); + return NULL; + } + pktbuf->size = 0; + + task->ctx = ctx; + task->pktbuf = pktbuf; + array_init(task->waiting); + task->refs = 0; + assert(ctx->task == NULL); + ctx->task = task; + /* Make the primary reference to task. */ + qr_task_ref(task); + task->creation_time = kr_now(); + ctx->worker->stats.concurrent += 1; + return task; +} + +/* This is called when the task refcount is zero, free memory. */ +static void qr_task_free(struct qr_task *task) +{ + struct request_ctx *ctx = task->ctx; + + assert(ctx); + + struct worker_ctx *worker = ctx->worker; + + if (ctx->task == NULL) { + request_free(ctx); + } + + /* Update stats */ + worker->stats.concurrent -= 1; +} + +/*@ Register new qr_task within session. */ +static int qr_task_register(struct qr_task *task, struct session *session) +{ + assert(!session_flags(session)->outgoing && session_get_handle(session)->type == UV_TCP); + + session_tasklist_add(session, task); + + struct request_ctx *ctx = task->ctx; + assert(ctx && (ctx->source.session == NULL || ctx->source.session == session)); + ctx->source.session = session; + /* Soft-limit on parallel queries, there is no "slow down" RCODE + * that we could use to signalize to client, but we can stop reading, + * an in effect shrink TCP window size. To get more precise throttling, + * we would need to copy remainder of the unread buffer and reassemble + * when resuming reading. This is NYI. */ + if (session_tasklist_get_len(session) >= task->ctx->worker->tcp_pipeline_max && + !session_flags(session)->throttled && !session_flags(session)->closing) { + session_stop_read(session); + session_flags(session)->throttled = true; + } + + return 0; +} + +static void qr_task_complete(struct qr_task *task) +{ + struct request_ctx *ctx = task->ctx; + + /* Kill pending I/O requests */ + ioreq_kill_pending(task); + assert(task->waiting.len == 0); + assert(task->leading == false); + + struct session *s = ctx->source.session; + if (s) { + assert(!session_flags(s)->outgoing && session_waitinglist_is_empty(s)); + ctx->source.session = NULL; + session_tasklist_del(s, task); + } + + /* Release primary reference to task. */ + if (ctx->task == task) { + ctx->task = NULL; + qr_task_unref(task); + } +} + +/* This is called when we send subrequest / answer */ +int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status) +{ + + if (task->finished) { + assert(task->leading == false); + qr_task_complete(task); + } + + if (!handle || handle->type != UV_TCP) { + return status; + } + + struct session* s = handle->data; + assert(s); + if (status != 0) { + session_tasklist_del(s, task); + } + + if (session_flags(s)->outgoing || session_flags(s)->closing) { + return status; + } + + struct worker_ctx *worker = task->ctx->worker; + if (session_flags(s)->throttled && + session_tasklist_get_len(s) < worker->tcp_pipeline_max/2) { + /* Start reading again if the session is throttled and + * the number of outgoing requests is below watermark. */ + session_start_read(s); + session_flags(s)->throttled = false; + } + + return status; +} + +static void on_send(uv_udp_send_t *req, int status) +{ + struct qr_task *task = req->data; + uv_handle_t *h = (uv_handle_t *)req->handle; + qr_task_on_send(task, h, status); + qr_task_unref(task); + free(req); +} + +static void on_write(uv_write_t *req, int status) +{ + struct qr_task *task = req->data; + uv_handle_t *h = (uv_handle_t *)req->handle; + qr_task_on_send(task, h, status); + qr_task_unref(task); + free(req); +} + +static int qr_task_send(struct qr_task *task, struct session *session, + const struct sockaddr *addr, knot_pkt_t *pkt) +{ + if (!session) { + return qr_task_on_send(task, NULL, kr_error(EIO)); + } + + int ret = 0; + struct request_ctx *ctx = task->ctx; + + uv_handle_t *handle = session_get_handle(session); + assert(handle && handle->data == session); + const bool is_stream = handle->type == UV_TCP; + if (!is_stream && handle->type != UV_UDP) abort(); + + if (addr == NULL) { + addr = session_get_peer(session); + } + + if (pkt == NULL) { + pkt = worker_task_get_pktbuf(task); + } + + if (session_flags(session)->outgoing && handle->type == UV_TCP) { + size_t try_limit = session_tasklist_get_len(session) + 1; + uint16_t msg_id = knot_wire_get_id(pkt->wire); + size_t try_count = 0; + while (session_tasklist_find_msgid(session, msg_id) && + try_count <= try_limit) { + ++msg_id; + ++try_count; + } + if (try_count > try_limit) { + return kr_error(ENOENT); + } + worker_task_pkt_set_msgid(task, msg_id); + } + + uv_handle_t *ioreq = malloc(is_stream ? sizeof(uv_write_t) : sizeof(uv_udp_send_t)); + if (!ioreq) { + return qr_task_on_send(task, handle, kr_error(ENOMEM)); + } + + /* Pending ioreq on current task */ + qr_task_ref(task); + + struct worker_ctx *worker = ctx->worker; + /* Send using given protocol */ + assert(!session_flags(session)->closing); + if (session_flags(session)->has_tls) { + uv_write_t *write_req = (uv_write_t *)ioreq; + write_req->data = task; + ret = tls_write(write_req, handle, pkt, &on_write); + } else if (handle->type == UV_UDP) { + uv_udp_send_t *send_req = (uv_udp_send_t *)ioreq; + uv_buf_t buf = { (char *)pkt->wire, pkt->size }; + send_req->data = task; + ret = uv_udp_send(send_req, (uv_udp_t *)handle, &buf, 1, addr, &on_send); + } else if (handle->type == UV_TCP) { + uv_write_t *write_req = (uv_write_t *)ioreq; + /* We need to write message length in native byte order, + * but we don't have a convenient place to store those bytes. + * The problem is that all memory referenced from buf[] MUST retain + * its contents at least until on_write() is called, and I currently + * can't see any convenient place outside the `pkt` structure. + * So we use directly the *individual* bytes in pkt->size. + * The call to htonl() and the condition will probably be inlinable. */ + int lsbi, slsbi; /* (second) least significant byte index */ + if (htonl(1) == 1) { /* big endian */ + lsbi = sizeof(pkt->size) - 1; + slsbi = sizeof(pkt->size) - 2; + } else { + lsbi = 0; + slsbi = 1; + } + uv_buf_t buf[3] = { + { (char *)&pkt->size + slsbi, 1 }, + { (char *)&pkt->size + lsbi, 1 }, + { (char *)pkt->wire, pkt->size }, + }; + write_req->data = task; + ret = uv_write(write_req, (uv_stream_t *)handle, buf, 3, &on_write); + } else { + assert(false); + } + + if (ret == 0) { + session_touch(session); + if (session_flags(session)->outgoing) { + session_tasklist_add(session, task); + } + if (worker->too_many_open && + worker->stats.rconcurrent < + worker->rconcurrent_highwatermark - 10) { + worker->too_many_open = false; + } + } else { + free(ioreq); + qr_task_unref(task); + if (ret == UV_EMFILE) { + worker->too_many_open = true; + worker->rconcurrent_highwatermark = worker->stats.rconcurrent; + ret = kr_error(UV_EMFILE); + } + } + + /* Update statistics */ + if (session_flags(session)->outgoing && addr) { + if (session_flags(session)->has_tls) + worker->stats.tls += 1; + else if (handle->type == UV_UDP) + worker->stats.udp += 1; + else + worker->stats.tcp += 1; + + if (addr->sa_family == AF_INET6) + worker->stats.ipv6 += 1; + else if (addr->sa_family == AF_INET) + worker->stats.ipv4 += 1; + } + return ret; +} + +static struct kr_query *task_get_last_pending_query(struct qr_task *task) +{ + if (!task || task->ctx->req.rplan.pending.len == 0) { + return NULL; + } + + return array_tail(task->ctx->req.rplan.pending); +} + +static int session_tls_hs_cb(struct session *session, int status) +{ + assert(session_flags(session)->outgoing); + uv_handle_t *handle = session_get_handle(session); + uv_loop_t *loop = handle->loop; + struct worker_ctx *worker = loop->data; + struct sockaddr *peer = session_get_peer(session); + int deletion_res = worker_del_tcp_waiting(worker, peer); + int ret = kr_ok(); + + if (status) { + struct qr_task *task = session_waitinglist_get(session); + if (task) { + struct kr_qflags *options = &task->ctx->req.options; + unsigned score = options->FORWARD || options->STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD; + kr_nsrep_update_rtt(NULL, peer, score, + worker->engine->resolver.cache_rtt, + KR_NS_UPDATE_NORESET); + } +#ifndef NDEBUG + else { + /* Task isn't in the list of tasks + * waiting for connection to upstream. + * So that it MUST be unsuccessful rehandshake. + * Check it. */ + assert(deletion_res != 0); + const char *key = tcpsess_key(peer); + assert(key); + assert(map_contains(&worker->tcp_connected, key) != 0); + } +#endif + return ret; + } + + /* handshake was completed successfully */ + struct tls_client_ctx_t *tls_client_ctx = session_tls_get_client_ctx(session); + tls_client_param_t *tls_params = tls_client_ctx->params; + gnutls_session_t tls_session = tls_client_ctx->c.tls_session; + if (gnutls_session_is_resumed(tls_session) != 0) { + kr_log_verbose("[tls_client] TLS session has resumed\n"); + } else { + kr_log_verbose("[tls_client] TLS session has not resumed\n"); + /* session wasn't resumed, delete old session data ... */ + if (tls_params->session_data.data != NULL) { + gnutls_free(tls_params->session_data.data); + tls_params->session_data.data = NULL; + tls_params->session_data.size = 0; + } + /* ... and get the new session data */ + gnutls_datum_t tls_session_data = { NULL, 0 }; + ret = gnutls_session_get_data2(tls_session, &tls_session_data); + if (ret == 0) { + tls_params->session_data = tls_session_data; + } + } + + struct session *s = worker_find_tcp_connected(worker, peer); + ret = kr_ok(); + if (deletion_res == kr_ok()) { + /* peer was in the waiting list, add to the connected list. */ + if (s) { + /* Something went wrong, + * peer already is in the connected list. */ + ret = kr_error(EINVAL); + } else { + ret = worker_add_tcp_connected(worker, peer, session); + } + } else { + /* peer wasn't in the waiting list. + * It can be + * 1) either successful rehandshake; in this case peer + * must be already in the connected list. + * 2) or successful handshake with session, which was timeouted + * by on_tcp_connect_timeout(); after successful tcp connection; + * in this case peer isn't in the connected list. + **/ + if (!s || s != session) { + ret = kr_error(EINVAL); + } + } + if (ret == kr_ok()) { + while (!session_waitinglist_is_empty(session)) { + struct qr_task *t = session_waitinglist_get(session); + ret = qr_task_send(t, session, NULL, NULL); + if (ret != 0) { + break; + } + session_waitinglist_pop(session, true); + } + } else { + ret = kr_error(EINVAL); + } + + if (ret != kr_ok()) { + /* Something went wrong. + * Either addition to the list of connected sessions + * or write to upstream failed. */ + worker_del_tcp_connected(worker, peer); + session_waitinglist_finalize(session, KR_STATE_FAIL); + assert(session_tasklist_is_empty(session)); + session_close(session); + } else { + session_timer_stop(session); + session_timer_start(session, tcp_timeout_trigger, + MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY); + } + return kr_ok(); +} + +static int send_waiting(struct session *session) +{ + int ret = 0; + while (!session_waitinglist_is_empty(session)) { + struct qr_task *t = session_waitinglist_get(session); + ret = qr_task_send(t, session, NULL, NULL); + if (ret != 0) { + struct worker_ctx *worker = t->ctx->worker; + struct sockaddr *peer = session_get_peer(session); + session_waitinglist_finalize(session, KR_STATE_FAIL); + session_tasklist_finalize(session, KR_STATE_FAIL); + worker_del_tcp_connected(worker, peer); + session_close(session); + break; + } + session_waitinglist_pop(session, true); + } + return ret; +} + +static void on_connect(uv_connect_t *req, int status) +{ + struct worker_ctx *worker = the_worker; + assert(worker); + uv_stream_t *handle = req->handle; + struct session *session = handle->data; + struct sockaddr *peer = session_get_peer(session); + free(req); + + assert(session_flags(session)->outgoing); + + if (session_flags(session)->closing) { + worker_del_tcp_waiting(worker, peer); + assert(session_is_empty(session)); + return; + } + + /* Check if the connection is in the waiting list. + * If no, most likely this is timeouted connection + * which was removed from waiting list by + * on_tcp_connect_timeout() callback. */ + struct session *s = worker_find_tcp_waiting(worker, peer); + if (!s || s != session) { + /* session isn't on the waiting list. + * it's timeouted session. */ + if (VERBOSE_STATUS) { + const char *peer_str = kr_straddr(peer); + kr_log_verbose( "[wrkr]=> connected to '%s', but session " + "is already timeouted, close\n", + peer_str ? peer_str : ""); + } + assert(session_tasklist_is_empty(session)); + session_waitinglist_retry(session, false); + session_close(session); + return; + } + + s = worker_find_tcp_connected(worker, peer); + if (s) { + /* session already in the connected list. + * Something went wrong, it can be due to races when kresd has tried + * to reconnect to upstream after unsuccessful attempt. */ + if (VERBOSE_STATUS) { + const char *peer_str = kr_straddr(peer); + kr_log_verbose( "[wrkr]=> connected to '%s', but peer " + "is already connected, close\n", + peer_str ? peer_str : ""); + } + assert(session_tasklist_is_empty(session)); + session_waitinglist_retry(session, false); + session_close(session); + return; + } + + if (status != 0) { + if (VERBOSE_STATUS) { + const char *peer_str = kr_straddr(peer); + kr_log_verbose( "[wrkr]=> connection to '%s' failed (%s), flagged as 'bad'\n", + peer_str ? peer_str : "", uv_strerror(status)); + } + worker_del_tcp_waiting(worker, peer); + struct qr_task *task = session_waitinglist_get(session); + if (task && status != UV_ETIMEDOUT) { + /* Penalize upstream. + * In case of UV_ETIMEDOUT upstream has been + * already penalized in on_tcp_connect_timeout() */ + struct kr_qflags *options = &task->ctx->req.options; + unsigned score = options->FORWARD || options->STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD; + kr_nsrep_update_rtt(NULL, peer, score, + worker->engine->resolver.cache_rtt, + KR_NS_UPDATE_NORESET); + } + assert(session_tasklist_is_empty(session)); + session_waitinglist_retry(session, false); + session_close(session); + return; + } + + if (!session_flags(session)->has_tls) { + /* if there is a TLS, session still waiting for handshake, + * otherwise remove it from waiting list */ + if (worker_del_tcp_waiting(worker, peer) != 0) { + /* session isn't in list of waiting queries, * + * something gone wrong */ + session_waitinglist_finalize(session, KR_STATE_FAIL); + assert(session_tasklist_is_empty(session)); + session_close(session); + return; + } + } + + if (VERBOSE_STATUS) { + const char *peer_str = kr_straddr(peer); + kr_log_verbose( "[wrkr]=> connected to '%s'\n", peer_str ? peer_str : ""); + } + + session_flags(session)->connected = true; + session_start_read(session); + + int ret = kr_ok(); + if (session_flags(session)->has_tls) { + struct tls_client_ctx_t *tls_ctx = session_tls_get_client_ctx(session); + ret = tls_client_connect_start(tls_ctx, session, session_tls_hs_cb); + if (ret == kr_error(EAGAIN)) { + session_timer_stop(session); + session_timer_start(session, tcp_timeout_trigger, + MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY); + return; + } + } else { + worker_add_tcp_connected(worker, peer, session); + } + + ret = send_waiting(session); + if (ret != 0) { + return; + } + + session_timer_stop(session); + session_timer_start(session, tcp_timeout_trigger, + MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY); +} + +static void on_tcp_connect_timeout(uv_timer_t *timer) +{ + struct session *session = timer->data; + + uv_timer_stop(timer); + struct worker_ctx *worker = the_worker; + assert(worker); + + assert (session_tasklist_is_empty(session)); + + struct sockaddr *peer = session_get_peer(session); + worker_del_tcp_waiting(worker, peer); + + struct qr_task *task = session_waitinglist_get(session); + if (!task) { + /* Normally shouldn't happen. */ + const char *peer_str = kr_straddr(peer); + VERBOSE_MSG(NULL, "=> connection to '%s' failed (internal timeout), empty waitinglist\n", + peer_str ? peer_str : ""); + return; + } + + struct kr_query *qry = task_get_last_pending_query(task); + WITH_VERBOSE (qry) { + const char *peer_str = kr_straddr(peer); + VERBOSE_MSG(qry, "=> connection to '%s' failed (internal timeout)\n", + peer_str ? peer_str : ""); + } + + unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD; + kr_nsrep_update_rtt(NULL, peer, score, + worker->engine->resolver.cache_rtt, + KR_NS_UPDATE_NORESET); + + worker->stats.timeout += session_waitinglist_get_len(session); + session_waitinglist_retry(session, true); + assert (session_tasklist_is_empty(session)); + /* uv_cancel() doesn't support uv_connect_t request, + * so that we can't cancel it. + * There still exists possibility of successful connection + * for this request. + * So connection callback (on_connect()) must check + * if connection is in the list of waiting connection. + * If no, most likely this is timeouted connection even if + * it was successful. */ +} + +/* This is called when I/O timeouts */ +static void on_udp_timeout(uv_timer_t *timer) +{ + struct session *session = timer->data; + assert(session_get_handle(session)->data == session); + assert(session_tasklist_get_len(session) == 1); + assert(session_waitinglist_is_empty(session)); + + uv_timer_stop(timer); + + /* Penalize all tried nameservers with a timeout. */ + struct qr_task *task = session_tasklist_get_first(session); + struct worker_ctx *worker = task->ctx->worker; + if (task->leading && task->pending_count > 0) { + struct kr_query *qry = array_tail(task->ctx->req.rplan.pending); + struct sockaddr_in6 *addrlist = (struct sockaddr_in6 *)task->addrlist; + for (uint16_t i = 0; i < MIN(task->pending_count, task->addrlist_count); ++i) { + struct sockaddr *choice = (struct sockaddr *)(&addrlist[i]); + WITH_VERBOSE(qry) { + char *addr_str = kr_straddr(choice); + VERBOSE_MSG(qry, "=> server: '%s' flagged as 'bad'\n", addr_str ? addr_str : ""); + } + unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD; + kr_nsrep_update_rtt(&qry->ns, choice, score, + worker->engine->resolver.cache_rtt, + KR_NS_UPDATE_NORESET); + } + } + task->timeouts += 1; + worker->stats.timeout += 1; + qr_task_step(task, NULL, NULL); +} + +static uv_handle_t *retransmit(struct qr_task *task) +{ + uv_handle_t *ret = NULL; + if (task && task->addrlist && task->addrlist_count > 0) { + struct sockaddr_in6 *choice = &((struct sockaddr_in6 *)task->addrlist)[task->addrlist_turn]; + if (!choice) { + return ret; + } + if (task->pending_count >= MAX_PENDING) { + return ret; + } + /* Checkout answer before sending it */ + struct request_ctx *ctx = task->ctx; + if (kr_resolve_checkout(&ctx->req, NULL, (struct sockaddr *)choice, SOCK_DGRAM, task->pktbuf) != 0) { + return ret; + } + ret = ioreq_spawn(ctx->worker, SOCK_DGRAM, choice->sin6_family, false); + if (!ret) { + return ret; + } + struct sockaddr *addr = (struct sockaddr *)choice; + struct session *session = ret->data; + struct sockaddr *peer = session_get_peer(session); + assert (peer->sa_family == AF_UNSPEC && session_flags(session)->outgoing); + memcpy(peer, addr, kr_sockaddr_len(addr)); + if (qr_task_send(task, session, (struct sockaddr *)choice, + task->pktbuf) != 0) { + session_close(session); + ret = NULL; + } else { + task->pending[task->pending_count] = session; + task->pending_count += 1; + task->addrlist_turn = (task->addrlist_turn + 1) % + task->addrlist_count; /* Round robin */ + session_start_read(session); /* Start reading answer */ + } + } + return ret; +} + +static void on_retransmit(uv_timer_t *req) +{ + struct session *session = req->data; + assert(session_tasklist_get_len(session) == 1); + + uv_timer_stop(req); + struct qr_task *task = session_tasklist_get_first(session); + if (retransmit(task) == NULL) { + /* Not possible to spawn request, start timeout timer with remaining deadline. */ + struct kr_qflags *options = &task->ctx->req.options; + uint64_t timeout = options->FORWARD || options->STUB ? KR_NS_FWD_TIMEOUT / 2 : + KR_CONN_RTT_MAX - task->pending_count * KR_CONN_RETRY; + uv_timer_start(req, on_udp_timeout, timeout, 0); + } else { + uv_timer_start(req, on_retransmit, KR_CONN_RETRY, 0); + } +} + +static void subreq_finalize(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *pkt) +{ + if (!task || task->finished) { + return; + } + /* Close pending timer */ + ioreq_kill_pending(task); + /* Clear from outgoing table. */ + if (!task->leading) + return; + char key[SUBREQ_KEY_LEN]; + const int klen = subreq_key(key, task->pktbuf); + if (klen > 0) { + void *val_deleted; + int ret = trie_del(task->ctx->worker->subreq_out, key, klen, &val_deleted); + assert(ret == KNOT_EOK && val_deleted == task); (void)ret; + } + /* Notify waiting tasks. */ + struct kr_query *leader_qry = array_tail(task->ctx->req.rplan.pending); + for (size_t i = task->waiting.len; i > 0; i--) { + struct qr_task *follower = task->waiting.at[i - 1]; + /* Reuse MSGID and 0x20 secret */ + if (follower->ctx->req.rplan.pending.len > 0) { + struct kr_query *qry = array_tail(follower->ctx->req.rplan.pending); + qry->id = leader_qry->id; + qry->secret = leader_qry->secret; + leader_qry->secret = 0; /* Next will be already decoded */ + } + qr_task_step(follower, packet_source, pkt); + qr_task_unref(follower); + } + task->waiting.len = 0; + task->leading = false; +} + +static void subreq_lead(struct qr_task *task) +{ + assert(task); + char key[SUBREQ_KEY_LEN]; + const int klen = subreq_key(key, task->pktbuf); + if (klen < 0) + return; + struct qr_task **tvp = (struct qr_task **) + trie_get_ins(task->ctx->worker->subreq_out, key, klen); + if (unlikely(!tvp)) + return; /*ENOMEM*/ + if (unlikely(*tvp != NULL)) { + assert(false); + return; + } + *tvp = task; + task->leading = true; +} + +static bool subreq_enqueue(struct qr_task *task) +{ + assert(task); + char key[SUBREQ_KEY_LEN]; + const int klen = subreq_key(key, task->pktbuf); + if (klen < 0) + return false; + struct qr_task **leader = (struct qr_task **) + trie_get_try(task->ctx->worker->subreq_out, key, klen); + if (!leader /*ENOMEM*/ || !*leader) + return false; + /* Enqueue itself to leader for this subrequest. */ + int ret = array_push_mm((*leader)->waiting, task, + kr_memreserve, &(*leader)->ctx->req.pool); + if (unlikely(ret < 0)) /*ENOMEM*/ + return false; + qr_task_ref(task); + return true; +} + +static int qr_task_finalize(struct qr_task *task, int state) +{ + assert(task && task->leading == false); + if (task->finished) { + return 0; + } + struct request_ctx *ctx = task->ctx; + struct session *source_session = ctx->source.session; + kr_resolve_finish(&ctx->req, state); + + task->finished = true; + if (source_session == NULL) { + (void) qr_task_on_send(task, NULL, kr_error(EIO)); + return state == KR_STATE_DONE ? 0 : kr_error(EIO); + } + + /* Reference task as the callback handler can close it */ + qr_task_ref(task); + + /* Send back answer */ + assert(!session_flags(source_session)->closing); + assert(ctx->source.addr.ip.sa_family != AF_UNSPEC); + + int ret; + const uv_handle_t *src_handle = session_get_handle(source_session); + if (src_handle->type != UV_UDP && src_handle->type != UV_TCP) { + assert(false); + ret = kr_error(EINVAL); + } else if (src_handle->type == UV_UDP && ENABLE_SENDMMSG) { + int fd; + ret = uv_fileno(src_handle, &fd); + assert(!ret); + if (ret == 0) { + udp_queue_push(fd, &ctx->req, task); + } + } else { + ret = qr_task_send(task, source_session, &ctx->source.addr.ip, ctx->req.answer); + } + + if (ret != kr_ok()) { + (void) qr_task_on_send(task, NULL, kr_error(EIO)); + /* Since source session is erroneous detach all tasks. */ + while (!session_tasklist_is_empty(source_session)) { + struct qr_task *t = session_tasklist_del_first(source_session, false); + struct request_ctx *c = t->ctx; + assert(c->source.session == source_session); + c->source.session = NULL; + /* Don't finalize them as there can be other tasks + * waiting for answer to this particular task. + * (ie. task->leading is true) */ + worker_task_unref(t); + } + session_close(source_session); + } + + qr_task_unref(task); + + return state == KR_STATE_DONE ? 0 : kr_error(EIO); +} + +static int udp_task_step(struct qr_task *task, + const struct sockaddr *packet_source, knot_pkt_t *packet) +{ + struct request_ctx *ctx = task->ctx; + struct kr_request *req = &ctx->req; + + /* If there is already outgoing query, enqueue to it. */ + if (subreq_enqueue(task)) { + return kr_ok(); /* Will be notified when outgoing query finishes. */ + } + /* Start transmitting */ + uv_handle_t *handle = retransmit(task); + if (handle == NULL) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + /* Check current query NSLIST */ + struct kr_query *qry = array_tail(req->rplan.pending); + assert(qry != NULL); + /* Retransmit at default interval, or more frequently if the mean + * RTT of the server is better. If the server is glued, use default rate. */ + size_t timeout = qry->ns.score; + if (timeout > KR_NS_GLUED) { + /* We don't have information about variance in RTT, expect +10ms */ + timeout = MIN(qry->ns.score + 10, KR_CONN_RETRY); + } else { + timeout = KR_CONN_RETRY; + } + /* Announce and start subrequest. + * @note Only UDP can lead I/O as it doesn't touch 'task->pktbuf' for reassembly. + */ + subreq_lead(task); + struct session *session = handle->data; + assert(session_get_handle(session) == handle && (handle->type == UV_UDP)); + int ret = session_timer_start(session, on_retransmit, timeout, 0); + /* Start next step with timeout, fatal if can't start a timer. */ + if (ret != 0) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + return kr_ok(); +} + +static int tcp_task_waiting_connection(struct session *session, struct qr_task *task) +{ + assert(session_flags(session)->outgoing); + if (session_flags(session)->closing) { + /* Something went wrong. Better answer with KR_STATE_FAIL. + * TODO: normally should not happen, + * consider possibility to transform this into + * assert(!session_flags(session)->closing). */ + return kr_error(EINVAL); + } + /* Add task to the end of list of waiting tasks. + * It will be notified in on_connect() or qr_task_on_send(). */ + int ret = session_waitinglist_push(session, task); + if (ret < 0) { + return kr_error(EINVAL); + } + return kr_ok(); +} + +static int tcp_task_existing_connection(struct session *session, struct qr_task *task) +{ + assert(session_flags(session)->outgoing); + struct request_ctx *ctx = task->ctx; + struct worker_ctx *worker = ctx->worker; + + if (session_flags(session)->closing) { + /* Something went wrong. Better answer with KR_STATE_FAIL. + * TODO: normally should not happen, + * consider possibility to transform this into + * assert(!session_flags(session)->closing). */ + return kr_error(EINVAL); + } + + /* If there are any unsent queries, send it first. */ + int ret = send_waiting(session); + if (ret != 0) { + return kr_error(EINVAL); + } + + /* No unsent queries at that point. */ + if (session_tasklist_get_len(session) >= worker->tcp_pipeline_max) { + /* Too many outstanding queries, answer with SERFVAIL, */ + return kr_error(EINVAL); + } + + /* Send query to upstream. */ + ret = qr_task_send(task, session, NULL, NULL); + if (ret != 0) { + /* Error, finalize task with SERVFAIL and + * close connection to upstream. */ + session_tasklist_finalize(session, KR_STATE_FAIL); + worker_del_tcp_connected(worker, session_get_peer(session)); + session_close(session); + return kr_error(EINVAL); + } + + return kr_ok(); +} + +static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr *addr) +{ + struct request_ctx *ctx = task->ctx; + struct worker_ctx *worker = ctx->worker; + + /* Check if there must be TLS */ + struct tls_client_ctx_t *tls_ctx = NULL; + struct network *net = &worker->engine->net; + tls_client_param_t *entry = tls_client_param_get(net->tls_client_params, addr); + if (entry) { + /* Address is configured to be used with TLS. + * We need to allocate auxiliary data structure. */ + tls_ctx = tls_client_ctx_new(entry, worker); + if (!tls_ctx) { + return kr_error(EINVAL); + } + } + + uv_connect_t *conn = malloc(sizeof(uv_connect_t)); + if (!conn) { + tls_client_ctx_free(tls_ctx); + return kr_error(EINVAL); + } + bool has_tls = (tls_ctx != NULL); + uv_handle_t *client = ioreq_spawn(worker, SOCK_STREAM, addr->sa_family, has_tls); + if (!client) { + tls_client_ctx_free(tls_ctx); + free(conn); + return kr_error(EINVAL); + } + struct session *session = client->data; + assert(session_flags(session)->has_tls == has_tls); + if (has_tls) { + tls_client_ctx_set_session(tls_ctx, session); + session_tls_set_client_ctx(session, tls_ctx); + } + + /* Add address to the waiting list. + * Now it "is waiting to be connected to." */ + int ret = worker_add_tcp_waiting(worker, addr, session); + if (ret < 0) { + free(conn); + session_close(session); + return kr_error(EINVAL); + } + + conn->data = session; + /* Store peer address for the session. */ + struct sockaddr *peer = session_get_peer(session); + memcpy(peer, addr, kr_sockaddr_len(addr)); + + /* Start watchdog to catch eventual connection timeout. */ + ret = session_timer_start(session, on_tcp_connect_timeout, + KR_CONN_RTT_MAX, 0); + if (ret != 0) { + worker_del_tcp_waiting(worker, addr); + free(conn); + session_close(session); + return kr_error(EINVAL); + } + + struct kr_query *qry = task_get_last_pending_query(task); + WITH_VERBOSE (qry) { + const char *peer_str = kr_straddr(peer); + VERBOSE_MSG(qry, "=> connecting to: '%s'\n", peer_str ? peer_str : ""); + } + + /* Start connection process to upstream. */ + ret = uv_tcp_connect(conn, (uv_tcp_t *)client, addr , on_connect); + if (ret != 0) { + session_timer_stop(session); + worker_del_tcp_waiting(worker, addr); + free(conn); + session_close(session); + unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD; + kr_nsrep_update_rtt(NULL, peer, score, + worker->engine->resolver.cache_rtt, + KR_NS_UPDATE_NORESET); + WITH_VERBOSE (qry) { + const char *peer_str = kr_straddr(peer); + kr_log_verbose( "[wrkr]=> connect to '%s' failed (%s), flagged as 'bad'\n", + peer_str ? peer_str : "", uv_strerror(ret)); + } + return kr_error(EAGAIN); + } + + /* Add task to the end of list of waiting tasks. + * Will be notified either in on_connect() or in qr_task_on_send(). */ + ret = session_waitinglist_push(session, task); + if (ret < 0) { + session_timer_stop(session); + worker_del_tcp_waiting(worker, addr); + free(conn); + session_close(session); + return kr_error(EINVAL); + } + + return kr_ok(); +} + +static int tcp_task_step(struct qr_task *task, + const struct sockaddr *packet_source, knot_pkt_t *packet) +{ + assert(task->pending_count == 0); + + /* target */ + const struct sockaddr *addr = task->addrlist; + if (addr->sa_family == AF_UNSPEC) { + /* Target isn't defined. Finalize task with SERVFAIL. + * Although task->pending_count is zero, there are can be followers, + * so we need to call subreq_finalize() to handle them properly. */ + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + /* Checkout task before connecting */ + struct request_ctx *ctx = task->ctx; + if (kr_resolve_checkout(&ctx->req, NULL, (struct sockaddr *)addr, + SOCK_STREAM, task->pktbuf) != 0) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + int ret; + struct session* session = NULL; + if ((session = worker_find_tcp_waiting(ctx->worker, addr)) != NULL) { + /* Connection is in the list of waiting connections. + * It means that connection establishing is coming right now. */ + ret = tcp_task_waiting_connection(session, task); + } else if ((session = worker_find_tcp_connected(ctx->worker, addr)) != NULL) { + /* Connection has been already established. */ + ret = tcp_task_existing_connection(session, task); + } else { + /* Make connection. */ + ret = tcp_task_make_connection(task, addr); + } + + if (ret != kr_ok()) { + subreq_finalize(task, addr, packet); + if (ret == kr_error(EAGAIN)) { + ret = qr_task_step(task, addr, NULL); + } else { + ret = qr_task_finalize(task, KR_STATE_FAIL); + } + } + + return ret; +} + +static int qr_task_step(struct qr_task *task, + const struct sockaddr *packet_source, knot_pkt_t *packet) +{ + /* No more steps after we're finished. */ + if (!task || task->finished) { + return kr_error(ESTALE); + } + + /* Close pending I/O requests */ + subreq_finalize(task, packet_source, packet); + if ((kr_now() - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) { + return qr_task_finalize(task, KR_STATE_FAIL); + } + + /* Consume input and produce next query */ + struct request_ctx *ctx = task->ctx; + assert(ctx); + struct kr_request *req = &ctx->req; + struct worker_ctx *worker = ctx->worker; + int sock_type = -1; + task->addrlist = NULL; + task->addrlist_count = 0; + task->addrlist_turn = 0; + + if (worker->too_many_open) { + /* */ + struct kr_rplan *rplan = &req->rplan; + if (worker->stats.rconcurrent < + worker->rconcurrent_highwatermark - 10) { + worker->too_many_open = false; + } else { + if (packet && kr_rplan_empty(rplan)) { + /* new query; TODO - make this detection more obvious */ + kr_resolve_consume(req, packet_source, packet); + } + return qr_task_finalize(task, KR_STATE_FAIL); + } + } + + int state = kr_resolve_consume(req, packet_source, packet); + while (state == KR_STATE_PRODUCE) { + state = kr_resolve_produce(req, &task->addrlist, + &sock_type, task->pktbuf); + if (unlikely(++task->iter_count > KR_ITER_LIMIT || + task->timeouts >= KR_TIMEOUT_LIMIT)) { + return qr_task_finalize(task, KR_STATE_FAIL); + } + } + + /* We're done, no more iterations needed */ + if (state & (KR_STATE_DONE|KR_STATE_FAIL)) { + return qr_task_finalize(task, state); + } else if (!task->addrlist || sock_type < 0) { + return qr_task_step(task, NULL, NULL); + } + + /* Count available address choices */ + struct sockaddr_in6 *choice = (struct sockaddr_in6 *)task->addrlist; + for (size_t i = 0; i < KR_NSREP_MAXADDR && choice->sin6_family != AF_UNSPEC; ++i) { + task->addrlist_count += 1; + choice += 1; + } + + /* Upgrade to TLS if the upstream address is configured as DoT capable. */ + if (task->addrlist_count > 0 && kr_inaddr_port(task->addrlist) == KR_DNS_PORT) { + /* TODO if there are multiple addresses (task->addrlist_count > 1) + * check all of them. */ + struct network *net = &worker->engine->net; + /* task->addrlist has to contain TLS port before tls_client_param_get() call */ + kr_inaddr_set_port(task->addrlist, KR_DNS_TLS_PORT); + tls_client_param_t *tls_entry = + tls_client_param_get(net->tls_client_params, task->addrlist); + if (tls_entry) { + packet_source = NULL; + sock_type = SOCK_STREAM; + /* TODO in this case in tcp_task_make_connection() will be performed + * redundant map_get() call. */ + } else { + /* The function is fairly cheap, so we just change there and back. */ + kr_inaddr_set_port(task->addrlist, KR_DNS_PORT); + } + } + + int ret = 0; + if (sock_type == SOCK_DGRAM) { + /* Start fast retransmit with UDP. */ + ret = udp_task_step(task, packet_source, packet); + } else { + /* TCP. Connect to upstream or send the query if connection already exists. */ + assert (sock_type == SOCK_STREAM); + ret = tcp_task_step(task, packet_source, packet); + } + return ret; +} + +static int parse_packet(knot_pkt_t *query) +{ + if (!query){ + return kr_error(EINVAL); + } + + /* Parse query packet. */ + int ret = knot_pkt_parse(query, 0); + if (ret == KNOT_ETRAIL) { + /* Extra data after message end. */ + ret = kr_error(EMSGSIZE); + } else if (ret != KNOT_EOK) { + /* Malformed query. */ + ret = kr_error(EPROTO); + } else { + ret = kr_ok(); + } + + return ret; +} + +int worker_submit(struct session *session, const struct sockaddr *peer, knot_pkt_t *query) +{ + if (!session) { + assert(false); + return kr_error(EINVAL); + } + + uv_handle_t *handle = session_get_handle(session); + bool OK = handle && handle->loop->data; + if (!OK) { + assert(false); + return kr_error(EINVAL); + } + + struct worker_ctx *worker = handle->loop->data; + + /* Parse packet */ + int ret = parse_packet(query); + + const bool is_query = (knot_wire_get_qr(query->wire) == 0); + const bool is_outgoing = session_flags(session)->outgoing; + /* Ignore badly formed queries. */ + if (!query || + (ret != kr_ok() && ret != kr_error(EMSGSIZE)) || + (is_query == is_outgoing)) { + if (query && !is_outgoing) worker->stats.dropped += 1; + return kr_error(EILSEQ); + } + + /* Start new task on listening sockets, + * or resume if this is subrequest */ + struct qr_task *task = NULL; + const struct sockaddr *addr = NULL; + if (!is_outgoing) { /* request from a client */ + struct request_ctx *ctx = request_create(worker, session, peer, + knot_wire_get_id(query->wire)); + if (!ctx) { + return kr_error(ENOMEM); + } + + ret = request_start(ctx, query); + if (ret != 0) { + request_free(ctx); + return kr_error(ENOMEM); + } + + task = qr_task_create(ctx); + if (!task) { + request_free(ctx); + return kr_error(ENOMEM); + } + + if (handle->type == UV_TCP && qr_task_register(task, session)) { + return kr_error(ENOMEM); + } + } else if (query) { /* response from upstream */ + const uint16_t id = knot_wire_get_id(query->wire); + task = session_tasklist_del_msgid(session, id); + if (task == NULL) { + VERBOSE_MSG(NULL, "=> ignoring packet with mismatching ID %d\n", + (int)id); + return kr_error(ENOENT); + } + assert(!session_flags(session)->closing); + addr = peer; + } + assert(uv_is_closing(session_get_handle(session)) == false); + + /* Packet was successfully parsed. + * Task was created (found). */ + session_touch(session); + /* Consume input and produce next message */ + return qr_task_step(task, addr, query); +} + +static int map_add_tcp_session(map_t *map, const struct sockaddr* addr, + struct session *session) +{ + assert(map && addr); + const char *key = tcpsess_key(addr); + assert(key); + assert(map_contains(map, key) == 0); + int ret = map_set(map, key, session); + return ret ? kr_error(EINVAL) : kr_ok(); +} + +static int map_del_tcp_session(map_t *map, const struct sockaddr* addr) +{ + assert(map && addr); + const char *key = tcpsess_key(addr); + assert(key); + int ret = map_del(map, key); + return ret ? kr_error(ENOENT) : kr_ok(); +} + +static struct session* map_find_tcp_session(map_t *map, + const struct sockaddr *addr) +{ + assert(map && addr); + const char *key = tcpsess_key(addr); + assert(key); + struct session* ret = map_get(map, key); + return ret; +} + +int worker_add_tcp_connected(struct worker_ctx *worker, + const struct sockaddr* addr, + struct session *session) +{ +#ifndef NDEBUG + assert(addr); + const char *key = tcpsess_key(addr); + assert(key); + assert(map_contains(&worker->tcp_connected, key) == 0); +#endif + return map_add_tcp_session(&worker->tcp_connected, addr, session); +} + +int worker_del_tcp_connected(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + assert(addr && tcpsess_key(addr)); + return map_del_tcp_session(&worker->tcp_connected, addr); +} + +static struct session* worker_find_tcp_connected(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + return map_find_tcp_session(&worker->tcp_connected, addr); +} + +static int worker_add_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr, + struct session *session) +{ +#ifndef NDEBUG + assert(addr); + const char *key = tcpsess_key(addr); + assert(key); + assert(map_contains(&worker->tcp_waiting, key) == 0); +#endif + return map_add_tcp_session(&worker->tcp_waiting, addr, session); +} + +int worker_del_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + assert(addr && tcpsess_key(addr)); + return map_del_tcp_session(&worker->tcp_waiting, addr); +} + +static struct session* worker_find_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + return map_find_tcp_session(&worker->tcp_waiting, addr); +} + +int worker_end_tcp(struct session *session) +{ + if (!session) { + return kr_error(EINVAL); + } + + session_timer_stop(session); + + uv_handle_t *handle = session_get_handle(session); + struct worker_ctx *worker = handle->loop->data; + struct sockaddr *peer = session_get_peer(session); + + worker_del_tcp_waiting(worker, peer); + worker_del_tcp_connected(worker, peer); + session_flags(session)->connected = false; + + struct tls_client_ctx_t *tls_client_ctx = session_tls_get_client_ctx(session); + if (tls_client_ctx) { + /* Avoid gnutls_bye() call */ + tls_set_hs_state(&tls_client_ctx->c, TLS_HS_NOT_STARTED); + } + + struct tls_ctx_t *tls_ctx = session_tls_get_server_ctx(session); + if (tls_ctx) { + /* Avoid gnutls_bye() call */ + tls_set_hs_state(&tls_ctx->c, TLS_HS_NOT_STARTED); + } + + while (!session_waitinglist_is_empty(session)) { + struct qr_task *task = session_waitinglist_pop(session, false); + assert(task->refs > 1); + session_tasklist_del(session, task); + if (session_flags(session)->outgoing) { + if (task->ctx->req.options.FORWARD) { + /* We are in TCP_FORWARD mode. + * To prevent failing at kr_resolve_consume() + * qry.flags.TCP must be cleared. + * TODO - refactoring is needed. */ + struct kr_request *req = &task->ctx->req; + struct kr_rplan *rplan = &req->rplan; + struct kr_query *qry = array_tail(rplan->pending); + qry->flags.TCP = false; + } + qr_task_step(task, NULL, NULL); + } else { + assert(task->ctx->source.session == session); + task->ctx->source.session = NULL; + } + worker_task_unref(task); + } + while (!session_tasklist_is_empty(session)) { + struct qr_task *task = session_tasklist_del_first(session, false); + if (session_flags(session)->outgoing) { + if (task->ctx->req.options.FORWARD) { + struct kr_request *req = &task->ctx->req; + struct kr_rplan *rplan = &req->rplan; + struct kr_query *qry = array_tail(rplan->pending); + qry->flags.TCP = false; + } + qr_task_step(task, NULL, NULL); + } else { + assert(task->ctx->source.session == session); + task->ctx->source.session = NULL; + } + worker_task_unref(task); + } + session_close(session); + return kr_ok(); +} + +knot_pkt_t * worker_resolve_mk_pkt(const char *qname_str, uint16_t qtype, uint16_t qclass, + const struct kr_qflags *options) +{ + uint8_t qname[KNOT_DNAME_MAXLEN]; + if (!knot_dname_from_str(qname, qname_str, sizeof(qname))) + return NULL; + knot_pkt_t *pkt = knot_pkt_new(NULL, KNOT_EDNS_MAX_UDP_PAYLOAD, NULL); + if (!pkt) + return NULL; + knot_pkt_put_question(pkt, qname, qclass, qtype); + knot_wire_set_rd(pkt->wire); + knot_wire_set_ad(pkt->wire); + + /* Add OPT RR */ + pkt->opt_rr = knot_rrset_copy(the_worker->engine->resolver.opt_rr, NULL); + if (!pkt->opt_rr) { + knot_pkt_free(pkt); + return NULL; + } + if (options->DNSSEC_WANT) { + knot_edns_set_do(pkt->opt_rr); + } + if (options->DNSSEC_CD) { + knot_wire_set_cd(pkt->wire); + } + + return pkt; +} + +struct qr_task *worker_resolve_start(knot_pkt_t *query, struct kr_qflags options) +{ + struct worker_ctx *worker = the_worker; + if (!worker || !query) { + assert(!EINVAL); + return NULL; + } + + + struct request_ctx *ctx = request_create(worker, NULL, NULL, worker->next_request_uid); + if (!ctx) { + return NULL; + } + + /* Create task */ + struct qr_task *task = qr_task_create(ctx); + if (!task) { + request_free(ctx); + return NULL; + } + + /* Start task */ + int ret = request_start(ctx, query); + if (ret != 0) { + /* task is attached to request context, + * so dereference (and deallocate) it first */ + ctx->task = NULL; + qr_task_unref(task); + request_free(ctx); + return NULL; + } + + worker->next_request_uid += 1; + if (worker->next_request_uid == 0) { + worker->next_request_uid = UINT16_MAX + 1; + } + + /* Set options late, as qr_task_start() -> kr_resolve_begin() rewrite it. */ + kr_qflags_set(&task->ctx->req.options, options); + return task; +} + +int worker_resolve_exec(struct qr_task *task, knot_pkt_t *query) +{ + if (!task) { + return kr_error(EINVAL); + } + return qr_task_step(task, NULL, query); +} + +int worker_task_numrefs(const struct qr_task *task) +{ + return task->refs; +} + +struct kr_request *worker_task_request(struct qr_task *task) +{ + if (!task || !task->ctx) { + return NULL; + } + + return &task->ctx->req; +} + +int worker_task_finalize(struct qr_task *task, int state) +{ + return qr_task_finalize(task, state); +} + + int worker_task_step(struct qr_task *task, const struct sockaddr *packet_source, + knot_pkt_t *packet) + { + return qr_task_step(task, packet_source, packet); + } + +void worker_task_complete(struct qr_task *task) +{ + qr_task_complete(task); +} + +void worker_task_ref(struct qr_task *task) +{ + qr_task_ref(task); +} + +void worker_task_unref(struct qr_task *task) +{ + qr_task_unref(task); +} + +void worker_task_timeout_inc(struct qr_task *task) +{ + task->timeouts += 1; +} + +knot_pkt_t *worker_task_get_pktbuf(const struct qr_task *task) +{ + return task->pktbuf; +} + +struct request_ctx *worker_task_get_request(struct qr_task *task) +{ + return task->ctx; +} + +struct session *worker_request_get_source_session(struct request_ctx *ctx) +{ + return ctx->source.session; +} + +void worker_request_set_source_session(struct request_ctx *ctx, struct session *session) +{ + ctx->source.session = session; +} + +uint16_t worker_task_pkt_get_msgid(struct qr_task *task) +{ + knot_pkt_t *pktbuf = worker_task_get_pktbuf(task); + uint16_t msg_id = knot_wire_get_id(pktbuf->wire); + return msg_id; +} + +void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid) +{ + knot_pkt_t *pktbuf = worker_task_get_pktbuf(task); + knot_wire_set_id(pktbuf->wire, msgid); + struct kr_query *q = task_get_last_pending_query(task); + q->id = msgid; +} + +uint64_t worker_task_creation_time(struct qr_task *task) +{ + return task->creation_time; +} + +void worker_task_subreq_finalize(struct qr_task *task) +{ + subreq_finalize(task, NULL, NULL); +} + +bool worker_task_finished(struct qr_task *task) +{ + return task->finished; +} + +/** Reserve worker buffers. We assume worker's been zeroed. */ +static int worker_reserve(struct worker_ctx *worker, size_t ring_maxlen) +{ + worker->tcp_connected = map_make(NULL); + worker->tcp_waiting = map_make(NULL); + worker->subreq_out = trie_create(NULL); + + array_init(worker->pool_mp); + if (array_reserve(worker->pool_mp, ring_maxlen)) { + return kr_error(ENOMEM); + } + + worker->pkt_pool.ctx = mp_new (4 * sizeof(knot_pkt_t)); + worker->pkt_pool.alloc = (knot_mm_alloc_t) mp_alloc; + + return kr_ok(); +} + +static inline void reclaim_mp_freelist(mp_freelist_t *list) +{ + for (unsigned i = 0; i < list->len; ++i) { + struct mempool *e = list->at[i]; + kr_asan_unpoison(e, sizeof(*e)); + mp_delete(e); + } + array_clear(*list); +} + +void worker_deinit(void) +{ + struct worker_ctx *worker = the_worker; + assert(worker); + if (worker->z_import != NULL) { + zi_free(worker->z_import); + worker->z_import = NULL; + } + map_clear(&worker->tcp_connected); + map_clear(&worker->tcp_waiting); + trie_free(worker->subreq_out); + worker->subreq_out = NULL; + + reclaim_mp_freelist(&worker->pool_mp); + mp_delete(worker->pkt_pool.ctx); + worker->pkt_pool.ctx = NULL; + + the_worker = NULL; +} + +int worker_init(struct engine *engine, int worker_id, int worker_count) +{ + assert(engine && engine->L); + assert(the_worker == NULL); + kr_bindings_register(engine->L); + + /* Create main worker. */ + struct worker_ctx *worker = &the_worker_value; + memset(worker, 0, sizeof(*worker)); + worker->engine = engine; + + uv_loop_t *loop = uv_default_loop(); + worker->loop = loop; + + worker->id = worker_id; + worker->count = worker_count; + + /* Register table for worker per-request variables */ + lua_newtable(engine->L); + lua_setfield(engine->L, -2, "vars"); + lua_getfield(engine->L, -1, "vars"); + worker->vars_table_ref = luaL_ref(engine->L, LUA_REGISTRYINDEX); + lua_pop(engine->L, 1); + + worker->tcp_pipeline_max = MAX_PIPELINED; + worker->out_addr4.sin_family = AF_UNSPEC; + worker->out_addr6.sin6_family = AF_UNSPEC; + + int ret = worker_reserve(worker, MP_FREELIST_SIZE); + if (ret) return ret; + worker->next_request_uid = UINT16_MAX + 1; + + /* Set some worker.* fields in Lua */ + lua_getglobal(engine->L, "worker"); + lua_pushnumber(engine->L, worker_id); + lua_setfield(engine->L, -2, "id"); + lua_pushnumber(engine->L, getpid()); + lua_setfield(engine->L, -2, "pid"); + lua_pushnumber(engine->L, worker_count); + lua_setfield(engine->L, -2, "count"); + + the_worker = worker; + loop->data = the_worker; + /* ^^^^ This shouldn't be used anymore, but it's hard to be 100% sure. */ + return kr_ok(); +} + +#undef VERBOSE_MSG diff --git a/utils/watcher/worker.h b/utils/watcher/worker.h new file mode 100644 index 000000000..5ff9f8ea9 --- /dev/null +++ b/utils/watcher/worker.h @@ -0,0 +1,190 @@ +/* Copyright (C) 2014-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include "engine.h" +#include "lib/generic/array.h" +#include "lib/generic/map.h" + + +/** Query resolution task (opaque). */ +struct qr_task; +/** Worker state (opaque). */ +struct worker_ctx; +/** Transport session (opaque). */ +struct session; +/** Zone import context (opaque). */ +struct zone_import_ctx; + +/** Pointer to the singleton worker. NULL if not initialized. */ +KR_EXPORT extern struct worker_ctx *the_worker; + +/** Create and initialize the worker. + * \return error code (ENOMEM) */ +int worker_init(struct engine *engine, int worker_id, int worker_count); + +/** Destroy the worker (free memory). */ +void worker_deinit(void); + +/** + * Process an incoming packet (query from a client or answer from upstream). + * + * @param session session the packet came from + * @param peer address the packet came from + * @param query the packet, or NULL on an error from the transport layer + * @return 0 or an error code + */ +int worker_submit(struct session *session, const struct sockaddr *peer, knot_pkt_t *query); + +/** + * End current DNS/TCP session, this disassociates pending tasks from this session + * which may be freely closed afterwards. + */ +int worker_end_tcp(struct session *session); + +/** + * Create a packet suitable for worker_resolve_start(). All in malloc() memory. + */ +KR_EXPORT knot_pkt_t * +worker_resolve_mk_pkt(const char *qname_str, uint16_t qtype, uint16_t qclass, + const struct kr_qflags *options); + +/** + * Start query resolution with given query. + * + * @return task or NULL + */ +KR_EXPORT struct qr_task * +worker_resolve_start(knot_pkt_t *query, struct kr_qflags options); + +/** + * Execute a request with given query. + * It expects task to be created with \fn worker_resolve_start. + * + * @return 0 or an error code + */ +KR_EXPORT int worker_resolve_exec(struct qr_task *task, knot_pkt_t *query); + +/** @return struct kr_request associated with opaque task */ +struct kr_request *worker_task_request(struct qr_task *task); + +int worker_task_step(struct qr_task *task, const struct sockaddr *packet_source, + knot_pkt_t *packet); + +int worker_task_numrefs(const struct qr_task *task); + +/** Finalize given task */ +int worker_task_finalize(struct qr_task *task, int state); + +void worker_task_complete(struct qr_task *task); + +void worker_task_ref(struct qr_task *task); + +void worker_task_unref(struct qr_task *task); + +void worker_task_timeout_inc(struct qr_task *task); + +int worker_add_tcp_connected(struct worker_ctx *worker, + const struct sockaddr *addr, + struct session *session); +int worker_del_tcp_connected(struct worker_ctx *worker, + const struct sockaddr *addr); +int worker_del_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr); +knot_pkt_t *worker_task_get_pktbuf(const struct qr_task *task); + +struct request_ctx *worker_task_get_request(struct qr_task *task); + +struct session *worker_request_get_source_session(struct request_ctx *); + +void worker_request_set_source_session(struct request_ctx *, struct session *session); + +uint16_t worker_task_pkt_get_msgid(struct qr_task *task); +void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid); +uint64_t worker_task_creation_time(struct qr_task *task); +void worker_task_subreq_finalize(struct qr_task *task); +bool worker_task_finished(struct qr_task *task); + +/** To be called after sending a DNS message. It mainly deals with cleanups. */ +int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status); + +/** Various worker statistics. Sync with wrk_stats() */ +struct worker_stats { + size_t queries; /**< Total number of requests (from clients and internal ones). */ + size_t concurrent; /**< The number of requests currently in processing. */ + size_t rconcurrent; /*< TODO: remove? I see no meaningful difference from .concurrent. */ + size_t dropped; /**< The number of requests dropped due to being badly formed. See #471. */ + + size_t timeout; /**< Number of outbound queries that timed out. */ + size_t udp; /**< Number of outbound queries over UDP. */ + size_t tcp; /**< Number of outbound queries over TCP (excluding TLS). */ + size_t tls; /**< Number of outbound queries over TLS. */ + size_t ipv4; /**< Number of outbound queries over IPv4.*/ + size_t ipv6; /**< Number of outbound queries over IPv6. */ +}; + +/** @cond internal */ + +/** Number of request within timeout window. */ +#define MAX_PENDING KR_NSREP_MAXADDR + +/** Maximum response time from TCP upstream, milliseconds */ +#define MAX_TCP_INACTIVITY (KR_RESOLVE_TIME_LIMIT + KR_CONN_RTT_MAX) + +#ifndef RECVMMSG_BATCH /* see check_bufsize() */ +#define RECVMMSG_BATCH 1 +#endif + +/** Freelist of available mempools. */ +typedef array_t(struct mempool *) mp_freelist_t; + +/** List of query resolution tasks. */ +typedef array_t(struct qr_task *) qr_tasklist_t; + +/** \details Worker state is meant to persist during the whole life of daemon. */ +struct worker_ctx { + struct engine *engine; + uv_loop_t *loop; + int id; + int count; + int vars_table_ref; + unsigned tcp_pipeline_max; + + /** Addresses to bind for outgoing connections or AF_UNSPEC. */ + struct sockaddr_in out_addr4; + struct sockaddr_in6 out_addr6; + + uint8_t wire_buf[RECVMMSG_BATCH * KNOT_WIRE_MAX_PKTSIZE]; + + struct worker_stats stats; + + struct zone_import_ctx* z_import; + bool too_many_open; + size_t rconcurrent_highwatermark; + /** List of active outbound TCP sessions */ + map_t tcp_connected; + /** List of outbound TCP sessions waiting to be accepted */ + map_t tcp_waiting; + /** Subrequest leaders (struct qr_task*), indexed by qname+qtype+qclass. */ + trie_t *subreq_out; + mp_freelist_t pool_mp; + knot_mm_t pkt_pool; + unsigned int next_request_uid; +}; + +/** @endcond */ + diff --git a/utils/watcher/zimport.c b/utils/watcher/zimport.c new file mode 100644 index 000000000..89abff6fd --- /dev/null +++ b/utils/watcher/zimport.c @@ -0,0 +1,821 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +/* Module is intended to import resource records from file into resolver's cache. + * File supposed to be a standard DNS zone file + * which contains text representations of resource records. + * For now only root zone import is supported. + * + * Import process consists of two stages. + * 1) Zone file parsing. + * 2) Import of parsed entries into the cache. + * + * These stages are implemented as two separate functions + * (zi_zone_import and zi_zone_process) which runs sequentially with the + * pause between them. This is done because resolver is a single-threaded + * application, so it can't process user's requests during the whole import + * process. Separation into two stages allows to reduce the + * continuous time interval when resolver can't serve user requests. + * Since root zone isn't large it is imported as single + * chunk. If it would be considered as necessary, import stage can be + * split into shorter stages. + * + * zi_zone_import() uses libzscanner to parse zone file. + * Parsed records are stored to internal storage from where they are imported to + * cache during the second stage. + * + * zi_zone_process() imports parsed resource records to cache. + * It imports rrset by creating request that will never be sent to upstream. + * After request creation resolver creates pseudo-answer which must contain + * all necessary data for validation. Then resolver process answer as if he had + * been received from network. + */ + +#include /* PRIu64 */ +#include +#include +#include +#include +#include + +#include "lib/utils.h" +#include "lib/dnssec/ta.h" +#include "worker.h" +#include "zimport.h" +#include "lib/generic/map.h" +#include "lib/generic/array.h" + +#define VERBOSE_MSG(qry, ...) QRVERBOSE(qry, "zimport", __VA_ARGS__) + +/* Pause between parse and import stages, milliseconds. + * See comment in zi_zone_import() */ +#define ZONE_IMPORT_PAUSE 100 + +typedef array_t(knot_rrset_t *) qr_rrsetlist_t; + +struct zone_import_ctx { + struct worker_ctx *worker; + bool started; + knot_dname_t *origin; + knot_rrset_t *ta; + knot_rrset_t *key; + uint64_t start_timestamp; + size_t rrset_idx; + uv_timer_t timer; + map_t rrset_indexed; + qr_rrsetlist_t rrset_sorted; + knot_mm_t pool; + zi_callback cb; + void *cb_param; +}; + +typedef struct zone_import_ctx zone_import_ctx_t; + +static int RRSET_IS_ALREADY_IMPORTED = 1; + +/** @internal Allocate zone import context. + * @return pointer to zone import context or NULL. */ +static zone_import_ctx_t *zi_ctx_alloc() +{ + return (zone_import_ctx_t *)malloc(sizeof(zone_import_ctx_t)); +} + +/** @internal Free zone import context. */ +static void zi_ctx_free(zone_import_ctx_t *z_import) +{ + if (z_import != NULL) { + free(z_import); + } +} + +/** @internal Reset all fields in the zone import context to their default values. + * Flushes memory pool, but doesn't reallocate memory pool buffer. + * Doesn't affect timer handle, pointers to callback and callback parameter. + * @return 0 if success; -1 if failed. */ +static int zi_reset(struct zone_import_ctx *z_import, size_t rrset_sorted_list_size) +{ + mp_flush(z_import->pool.ctx); + + z_import->started = false; + z_import->start_timestamp = 0; + z_import->rrset_idx = 0; + z_import->pool.alloc = (knot_mm_alloc_t) mp_alloc; + z_import->rrset_indexed = map_make(&z_import->pool); + + array_init(z_import->rrset_sorted); + + int ret = 0; + if (rrset_sorted_list_size) { + ret = array_reserve_mm(z_import->rrset_sorted, rrset_sorted_list_size, + kr_memreserve, &z_import->pool); + } + + return ret; +} + +/** @internal Close callback for timer handle. + * @note Actually frees zone import context. */ +static void on_timer_close(uv_handle_t *handle) +{ + zone_import_ctx_t *z_import = (zone_import_ctx_t *)handle->data; + if (z_import != NULL) { + zi_ctx_free(z_import); + } +} + +zone_import_ctx_t *zi_allocate(struct worker_ctx *worker, + zi_callback cb, void *param) +{ + if (worker->loop == NULL) { + return NULL; + } + zone_import_ctx_t *z_import = zi_ctx_alloc(); + if (!z_import) { + return NULL; + } + void *mp = mp_new (8192); + if (!mp) { + zi_ctx_free(z_import); + return NULL; + } + memset(z_import, 0, sizeof(*z_import)); + z_import->pool.ctx = mp; + z_import->worker = worker; + int ret = zi_reset(z_import, 0); + if (ret < 0) { + mp_delete(mp); + zi_ctx_free(z_import); + return NULL; + } + uv_timer_init(z_import->worker->loop, &z_import->timer); + z_import->timer.data = z_import; + z_import->cb = cb; + z_import->cb_param = param; + return z_import; +} + +void zi_free(zone_import_ctx_t *z_import) +{ + z_import->started = false; + z_import->start_timestamp = 0; + z_import->rrset_idx = 0; + mp_delete(z_import->pool.ctx); + z_import->pool.ctx = NULL; + z_import->pool.alloc = NULL; + z_import->worker = NULL; + z_import->cb = NULL; + z_import->cb_param = NULL; + uv_close((uv_handle_t *)&z_import->timer, on_timer_close); +} + +/** @internal Mark rrset that has been already imported + * to avoid repeated import. */ +static inline void zi_rrset_mark_as_imported(knot_rrset_t *rr) +{ + rr->additional = (void *)&RRSET_IS_ALREADY_IMPORTED; +} + +/** @internal Check if rrset is marked as "already imported". + * @return true if marked, false if isn't */ +static inline bool zi_rrset_is_marked_as_imported(knot_rrset_t *rr) +{ + return (rr->additional == &RRSET_IS_ALREADY_IMPORTED); +} + +/** @internal Try to find rrset with given requisites amongst parsed rrsets + * and put it to given packet. If there is RRSIG which covers that rrset, it + * will be added as well. If rrset found and successfully put, it marked as + * "already imported" to avoid repeated import. The same is true for RRSIG. + * @return -1 if failed + * 0 if required record been actually put into the packet + * 1 if required record could not be found */ +static int zi_rrset_find_put(struct zone_import_ctx *z_import, + knot_pkt_t *pkt, const knot_dname_t *owner, + uint16_t class, uint16_t type, uint16_t additional) +{ + if (type != KNOT_RRTYPE_RRSIG) { + /* If required rrset isn't rrsig, these must be the same values */ + additional = type; + } + + char key[KR_RRKEY_LEN]; + int err = kr_rrkey(key, class, owner, type, additional); + if (err <= 0) { + return -1; + } + knot_rrset_t *rr = map_get(&z_import->rrset_indexed, key); + if (!rr) { + return 1; + } + err = knot_pkt_put(pkt, 0, rr, 0); + if (err != KNOT_EOK) { + return -1; + } + zi_rrset_mark_as_imported(rr); + + if (type != KNOT_RRTYPE_RRSIG) { + /* Try to find corresponding rrsig */ + err = zi_rrset_find_put(z_import, pkt, owner, + class, KNOT_RRTYPE_RRSIG, type); + if (err < 0) { + return err; + } + } + + return 0; +} + +/** @internal Try to put given rrset to the given packet. + * If there is RRSIG which covers that rrset, it will be added as well. + * If rrset successfully put in the packet, it marked as + * "already imported" to avoid repeated import. + * The same is true for RRSIG. + * @return -1 if failed + * 0 if required record been actually put into the packet */ +static int zi_rrset_put(struct zone_import_ctx *z_import, knot_pkt_t *pkt, + knot_rrset_t *rr) +{ + assert(rr); + assert(rr->type != KNOT_RRTYPE_RRSIG); + int err = knot_pkt_put(pkt, 0, rr, 0); + if (err != KNOT_EOK) { + return -1; + } + zi_rrset_mark_as_imported(rr); + /* Try to find corresponding RRSIG */ + err = zi_rrset_find_put(z_import, pkt, rr->owner, rr->rclass, + KNOT_RRTYPE_RRSIG, rr->type); + return (err < 0) ? err : 0; +} + +/** @internal Try to put DS & NSEC* for rset->owner to given packet. + * @return -1 if failed; + * 0 if no errors occurred (it doesn't mean + * that records were actually added). */ +static int zi_put_delegation(zone_import_ctx_t *z_import, knot_pkt_t *pkt, + knot_rrset_t *rr) +{ + int err = zi_rrset_find_put(z_import, pkt, rr->owner, + rr->rclass, KNOT_RRTYPE_DS, 0); + if (err == 1) { + /* DS not found, maybe there are NSEC* */ + err = zi_rrset_find_put(z_import, pkt, rr->owner, + rr->rclass, KNOT_RRTYPE_NSEC, 0); + if (err >= 0) { + err = zi_rrset_find_put(z_import, pkt, rr->owner, + rr->rclass, KNOT_RRTYPE_NSEC3, 0); + } + } + return err < 0 ? err : 0; +} + +/** @internal Try to put A & AAAA records for rset->owner to given packet. + * @return -1 if failed; + * 0 if no errors occurred (it doesn't mean + * that records were actually added). */ +static int zi_put_glue(zone_import_ctx_t *z_import, knot_pkt_t *pkt, + knot_rrset_t *rr) +{ + int err = 0; + knot_rdata_t *rdata_i = rr->rrs.rdata; + for (uint16_t i = 0; i < rr->rrs.count; + ++i, rdata_i = knot_rdataset_next(rdata_i)) { + const knot_dname_t *ns_name = knot_ns_name(rdata_i); + err = zi_rrset_find_put(z_import, pkt, ns_name, + rr->rclass, KNOT_RRTYPE_A, 0); + if (err < 0) { + break; + } + + err = zi_rrset_find_put(z_import, pkt, ns_name, + rr->rclass, KNOT_RRTYPE_AAAA, 0); + if (err < 0) { + break; + } + } + return err < 0 ? err : 0; +} + +/** @internal Create query. */ +static knot_pkt_t *zi_query_create(zone_import_ctx_t *z_import, knot_rrset_t *rr) +{ + knot_mm_t *pool = &z_import->pool; + + uint32_t msgid = kr_rand_bytes(2); + + knot_pkt_t *query = knot_pkt_new(NULL, KNOT_WIRE_MAX_PKTSIZE, pool); + if (!query) { + return NULL; + } + + knot_pkt_put_question(query, rr->owner, rr->rclass, rr->type); + knot_pkt_begin(query, KNOT_ANSWER); + knot_wire_set_rd(query->wire); + knot_wire_set_id(query->wire, msgid); + int err = knot_pkt_parse(query, 0); + if (err != KNOT_EOK) { + knot_pkt_free(query); + return NULL; + } + + return query; +} + +/** @internal Import given rrset to cache. + * @return -1 if failed; 0 if success */ +static int zi_rrset_import(zone_import_ctx_t *z_import, knot_rrset_t *rr) +{ + /* Create "pseudo query" which asks for given rrset. */ + knot_pkt_t *query = zi_query_create(z_import, rr); + if (!query) { + return -1; + } + + knot_mm_t *pool = &z_import->pool; + uint8_t *dname = rr->owner; + uint16_t rrtype = rr->type; + uint16_t rrclass = rr->rclass; + + /* Create "pseudo answer". */ + knot_pkt_t *answer = knot_pkt_new(NULL, KNOT_WIRE_MAX_PKTSIZE, pool); + if (!answer) { + knot_pkt_free(query); + return -1; + } + knot_pkt_put_question(answer, dname, rrclass, rrtype); + knot_pkt_begin(answer, KNOT_ANSWER); + + struct kr_qflags options; + memset(&options, 0, sizeof(options)); + options.DNSSEC_WANT = true; + options.NO_MINIMIZE = true; + + /* This call creates internal structures which necessary for + * resolving - qr_task & request_ctx. */ + struct qr_task *task = worker_resolve_start(query, options); + if (!task) { + knot_pkt_free(query); + knot_pkt_free(answer); + return -1; + } + + /* Push query to the request resolve plan. + * Actually query will never been sent to upstream. */ + struct kr_request *request = worker_task_request(task); + struct kr_rplan *rplan = &request->rplan; + struct kr_query *qry = kr_rplan_push(rplan, NULL, dname, rrclass, rrtype); + int state = KR_STATE_FAIL; + bool origin_is_owner = knot_dname_is_equal(rr->owner, z_import->origin); + bool is_referral = (rrtype == KNOT_RRTYPE_NS && !origin_is_owner); + uint32_t msgid = knot_wire_get_id(query->wire); + + qry->id = msgid; + + /* Prepare zonecut. It must have all the necessary requisites for + * successful validation - matched zone name & keys & trust-anchors. */ + kr_zonecut_init(&qry->zone_cut, z_import->origin, pool); + qry->zone_cut.key = z_import->key; + qry->zone_cut.trust_anchor = z_import->ta; + + if (knot_pkt_init_response(request->answer, query) != 0) { + goto cleanup; + } + + /* Since "pseudo" query asks for NS for subzone, + * "pseudo" answer must simulate referral. */ + if (is_referral) { + knot_pkt_begin(answer, KNOT_AUTHORITY); + } + + /* Put target rrset to ANSWER\AUTHORIRY as well as corresponding RRSIG */ + int err = zi_rrset_put(z_import, answer, rr); + if (err != 0) { + goto cleanup; + } + + if (!is_referral) { + knot_wire_set_aa(answer->wire); + } else { + /* Type is KNOT_RRTYPE_NS and owner is not equal to origin. + * It will be "referral" answer and must contain delegation. */ + err = zi_put_delegation(z_import, answer, rr); + if (err < 0) { + goto cleanup; + } + } + + knot_pkt_begin(answer, KNOT_ADDITIONAL); + + if (rrtype == KNOT_RRTYPE_NS) { + /* Try to find glue addresses. */ + err = zi_put_glue(z_import, answer, rr); + if (err < 0) { + goto cleanup; + } + } + + knot_wire_set_id(answer->wire, msgid); + answer->parsed = answer->size; + err = knot_pkt_parse(answer, 0); + if (err != KNOT_EOK) { + goto cleanup; + } + + /* Importing doesn't imply communication with upstream at all. + * "answer" contains pseudo-answer from upstream and must be successfully + * validated in CONSUME stage. If not, something gone wrong. */ + state = kr_resolve_consume(request, NULL, answer); + +cleanup: + + knot_pkt_free(query); + knot_pkt_free(answer); + worker_task_finalize(task, state); + return state == (is_referral ? KR_STATE_PRODUCE : KR_STATE_DONE) ? 0 : -1; +} + +/** @internal Create element in qr_rrsetlist_t rrset_list for + * given node of map_t rrset_sorted. */ +static int zi_mapwalk_preprocess(const char *k, void *v, void *baton) +{ + zone_import_ctx_t *z_import = (zone_import_ctx_t *)baton; + + int ret = array_push_mm(z_import->rrset_sorted, v, kr_memreserve, &z_import->pool); + + return (ret < 0); +} + +/** @internal Iterate over parsed rrsets and try to import each of them. */ +static void zi_zone_process(uv_timer_t* handle) +{ + zone_import_ctx_t *z_import = (zone_import_ctx_t *)handle->data; + + assert(z_import->worker); + + size_t failed = 0; + size_t ns_imported = 0; + size_t other_imported = 0; + + /* At the moment import of root zone only is supported. + * Check the name of the parsed zone. + * TODO - implement importing of arbitrary zone. */ + KR_DNAME_GET_STR(zone_name_str, z_import->origin); + + if (strcmp(".", zone_name_str) != 0) { + kr_log_error("[zimport] unexpected zone name `%s` (root zone expected), fail\n", + zone_name_str); + failed = 1; + goto finish; + } + + if (z_import->rrset_sorted.len <= 0) { + VERBOSE_MSG(NULL, "zone is empty\n"); + goto finish; + } + + /* TA have been found, zone is secured. + * DNSKEY must be somewhere amongst the imported records. Find it. + * TODO - For those zones that provenly do not have TA this step must be skipped. */ + char key[KR_RRKEY_LEN]; + int err = kr_rrkey(key, KNOT_CLASS_IN, z_import->origin, + KNOT_RRTYPE_DNSKEY, KNOT_RRTYPE_DNSKEY); + if (err <= 0) { + failed = 1; + goto finish; + } + + knot_rrset_t *rr_key = map_get(&z_import->rrset_indexed, key); + if (!rr_key) { + /* DNSKEY MUST be here. If not found - fail. */ + kr_log_error("[zimport] DNSKEY not found for `%s`, fail\n", zone_name_str); + failed = 1; + goto finish; + } + z_import->key = rr_key; + + VERBOSE_MSG(NULL, "started: zone: '%s'\n", zone_name_str); + + z_import->start_timestamp = kr_now(); + + /* Import DNSKEY at first step. If any validation problems will appear, + * cancel import of whole zone. */ + KR_DNAME_GET_STR(kname_str, rr_key->owner); + KR_RRTYPE_GET_STR(ktype_str, rr_key->type); + + VERBOSE_MSG(NULL, "importing: name: '%s' type: '%s'\n", + kname_str, ktype_str); + + int res = zi_rrset_import(z_import, rr_key); + if (res != 0) { + VERBOSE_MSG(NULL, "import failed: qname: '%s' type: '%s'\n", + kname_str, ktype_str); + failed = 1; + goto finish; + } + + /* Import all NS records */ + for (size_t i = 0; i < z_import->rrset_sorted.len; ++i) { + knot_rrset_t *rr = z_import->rrset_sorted.at[i]; + + if (rr->type != KNOT_RRTYPE_NS) { + continue; + } + + KR_DNAME_GET_STR(name_str, rr->owner); + KR_RRTYPE_GET_STR(type_str, rr->type); + VERBOSE_MSG(NULL, "importing: name: '%s' type: '%s'\n", + name_str, type_str); + int ret = zi_rrset_import(z_import, rr); + if (ret == 0) { + ++ns_imported; + } else { + VERBOSE_MSG(NULL, "import failed: name: '%s' type: '%s'\n", + name_str, type_str); + ++failed; + } + z_import->rrset_sorted.at[i] = NULL; + } + + /* NS records have been imported as well as relative DS, NSEC* and glue. + * Now import what's left. */ + for (size_t i = 0; i < z_import->rrset_sorted.len; ++i) { + + knot_rrset_t *rr = z_import->rrset_sorted.at[i]; + if (rr == NULL) { + continue; + } + + if (zi_rrset_is_marked_as_imported(rr)) { + continue; + } + + if (rr->type == KNOT_RRTYPE_DNSKEY || rr->type == KNOT_RRTYPE_RRSIG) { + continue; + } + + KR_DNAME_GET_STR(name_str, rr->owner); + KR_RRTYPE_GET_STR(type_str, rr->type); + VERBOSE_MSG(NULL, "importing: name: '%s' type: '%s'\n", + name_str, type_str); + res = zi_rrset_import(z_import, rr); + if (res == 0) { + ++other_imported; + } else { + VERBOSE_MSG(NULL, "import failed: name: '%s' type: '%s'\n", + name_str, type_str); + ++failed; + } + } + + uint64_t elapsed = kr_now() - z_import->start_timestamp; + elapsed = elapsed > UINT_MAX ? UINT_MAX : elapsed; + + VERBOSE_MSG(NULL, "finished in %"PRIu64" ms; zone: `%s`; ns: %zd" + "; other: %zd; failed: %zd\n", + elapsed, zone_name_str, ns_imported, other_imported, failed); + +finish: + + uv_timer_stop(&z_import->timer); + z_import->started = false; + + int import_state = 0; + + if (failed != 0) { + if (ns_imported == 0 && other_imported == 0) { + import_state = -1; + VERBOSE_MSG(NULL, "import failed; zone `%s` \n", zone_name_str); + } else { + import_state = 1; + } + } else { + import_state = 0; + } + + if (z_import->cb != NULL) { + z_import->cb(import_state, z_import->cb_param); + } +} + +/** @internal Store rrset that has been imported to zone import context memory pool. + * @return -1 if failed; 0 if success. */ +static int zi_record_store(zs_scanner_t *s) +{ + if (s->r_data_length > UINT16_MAX) { + /* Due to knot_rrset_add_rdata(..., const uint16_t size, ...); */ + kr_log_error("[zscanner] line %"PRIu64": rdata is too long\n", + s->line_counter); + return -1; + } + + if (knot_dname_size(s->r_owner) != strlen((const char *)(s->r_owner)) + 1) { + kr_log_error("[zscanner] line %"PRIu64 + ": owner name contains zero byte, skip\n", + s->line_counter); + return 0; + } + + zone_import_ctx_t *z_import = (zone_import_ctx_t *)s->process.data; + + knot_rrset_t *new_rr = knot_rrset_new(s->r_owner, s->r_type, s->r_class, + s->r_ttl, &z_import->pool); + if (!new_rr) { + kr_log_error("[zscanner] line %"PRIu64": error creating rrset\n", + s->line_counter); + return -1; + } + int res = knot_rrset_add_rdata(new_rr, s->r_data, s->r_data_length, + &z_import->pool); + if (res != KNOT_EOK) { + kr_log_error("[zscanner] line %"PRIu64": error adding rdata to rrset\n", + s->line_counter); + return -1; + } + + /* Records in zone file may not be grouped by name and RR type. + * Use map to create search key and + * avoid ineffective searches across all the imported records. */ + char key[KR_RRKEY_LEN]; + uint16_t additional_key_field = kr_rrset_type_maysig(new_rr); + + res = kr_rrkey(key, new_rr->rclass, new_rr->owner, new_rr->type, + additional_key_field); + if (res <= 0) { + kr_log_error("[zscanner] line %"PRIu64": error constructing rrkey\n", + s->line_counter); + return -1; + } + + knot_rrset_t *saved_rr = map_get(&z_import->rrset_indexed, key); + if (saved_rr) { + res = knot_rdataset_merge(&saved_rr->rrs, &new_rr->rrs, + &z_import->pool); + } else { + res = map_set(&z_import->rrset_indexed, key, new_rr); + } + if (res != 0) { + kr_log_error("[zscanner] line %"PRIu64": error saving parsed rrset\n", + s->line_counter); + return -1; + } + + return 0; +} + +/** @internal zscanner callback. */ +static int zi_state_parsing(zs_scanner_t *s) +{ + bool empty = true; + while (zs_parse_record(s) == 0) { + switch (s->state) { + case ZS_STATE_DATA: + if (zi_record_store(s) != 0) { + return -1; + } + zone_import_ctx_t *z_import = (zone_import_ctx_t *) s->process.data; + empty = false; + if (s->r_type == 6) { + z_import->origin = knot_dname_copy(s->r_owner, + &z_import->pool); + } + break; + case ZS_STATE_ERROR: + kr_log_error("[zscanner] line: %"PRIu64 + ": parse error; code: %i ('%s')\n", + s->line_counter, s->error.code, + zs_strerror(s->error.code)); + return -1; + case ZS_STATE_INCLUDE: + kr_log_error("[zscanner] line: %"PRIu64 + ": INCLUDE is not supported\n", + s->line_counter); + return -1; + case ZS_STATE_EOF: + case ZS_STATE_STOP: + if (empty) { + kr_log_error("[zimport] empty zone file\n"); + return -1; + } + if (!((zone_import_ctx_t *) s->process.data)->origin) { + kr_log_error("[zimport] zone file doesn't contain SOA record\n"); + return -1; + } + return (s->error.counter == 0) ? 0 : -1; + default: + kr_log_error("[zscanner] line: %"PRIu64 + ": unexpected parse state: %i\n", + s->line_counter, s->state); + return -1; + } + } + + return -1; +} + +int zi_zone_import(struct zone_import_ctx *z_import, + const char *zone_file, const char *origin, + uint16_t rclass, uint32_t ttl) +{ + assert (z_import != NULL && "[zimport] empty parameter"); + assert (z_import->worker != NULL && "[zimport] invalid parameter\n"); + assert (zone_file != NULL && "[zimport] empty parameter\n"); + + zs_scanner_t *s = malloc(sizeof(zs_scanner_t)); + if (s == NULL) { + kr_log_error("[zscanner] error creating instance of zone scanner (malloc() fails)\n"); + return -1; + } + + /* zs_init(), zs_set_input_file(), zs_set_processing() returns -1 in case of error, + * so don't print error code as it meaningless. */ + int res = zs_init(s, origin, rclass, ttl); + if (res != 0) { + kr_log_error("[zscanner] error initializing zone scanner instance, error: %i (%s)\n", + s->error.code, zs_strerror(s->error.code)); + free(s); + return -1; + } + + res = zs_set_input_file(s, zone_file); + if (res != 0) { + kr_log_error("[zscanner] error opening zone file `%s`, error: %i (%s)\n", + zone_file, s->error.code, zs_strerror(s->error.code)); + zs_deinit(s); + free(s); + return -1; + } + + /* Don't set processing and error callbacks as we don't use automatic parsing. + * Parsing as well error processing will be performed in zi_state_parsing(). + * Store pointer to zone import context for further use. */ + if (zs_set_processing(s, NULL, NULL, (void *)z_import) != 0) { + kr_log_error("[zscanner] zs_set_processing() failed for zone file `%s`, " + "error: %i (%s)\n", + zone_file, s->error.code, zs_strerror(s->error.code)); + zs_deinit(s); + free(s); + return -1; + } + + uint64_t elapsed = 0; + int ret = zi_reset(z_import, 4096); + if (ret == 0) { + z_import->started = true; + z_import->start_timestamp = kr_now(); + VERBOSE_MSG(NULL, "[zscanner] started; zone file `%s`\n", + zone_file); + ret = zi_state_parsing(s); + if (ret == 0) { + /* Try to find TA for worker->z_import.origin. */ + map_t *trust_anchors = &z_import->worker->engine->resolver.trust_anchors; + knot_rrset_t *rr = kr_ta_get(trust_anchors, z_import->origin); + if (rr) { + z_import->ta = rr; + } else { + /* For now - fail. + * TODO - query DS and continue after answer had been obtained. */ + KR_DNAME_GET_STR(zone_name_str, z_import->origin); + kr_log_error("[zimport] no TA found for `%s`, fail\n", zone_name_str); + ret = 1; + } + elapsed = kr_now() - z_import->start_timestamp; + elapsed = elapsed > UINT_MAX ? UINT_MAX : elapsed; + } + } + zs_deinit(s); + free(s); + + if (ret != 0) { + kr_log_error("[zscanner] error parsing zone file `%s`\n", zone_file); + z_import->started = false; + return ret; + } + + VERBOSE_MSG(NULL, "[zscanner] finished in %"PRIu64" ms; zone file `%s`\n", + elapsed, zone_file); + map_walk(&z_import->rrset_indexed, zi_mapwalk_preprocess, z_import); + + /* Zone have been parsed already, so start the import. */ + uv_timer_start(&z_import->timer, zi_zone_process, + ZONE_IMPORT_PAUSE, ZONE_IMPORT_PAUSE); + + return 0; +} + +bool zi_import_started(struct zone_import_ctx *z_import) +{ + return z_import ? z_import->started : false; +} diff --git a/utils/watcher/zimport.h b/utils/watcher/zimport.h new file mode 100644 index 000000000..57b246e95 --- /dev/null +++ b/utils/watcher/zimport.h @@ -0,0 +1,68 @@ +/* Copyright (C) 2018 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +#pragma once + +#include + +struct worker_ctx; +/** Zone import context (opaque). */ +struct zone_import_ctx; + +/** + * Completion callback + * + * @param state -1 - fail + * 0 - success + * 1 - success, but there are non-critical errors + * @param pointer to user data + */ +typedef void (*zi_callback)(int state, void *param); + +/** + * Allocate and initialize zone import context. + * + * @param worker pointer to worker state + * @return NULL or pointer to zone import context. + */ +struct zone_import_ctx *zi_allocate(struct worker_ctx *worker, + zi_callback cb, void *param); + +/** Free zone import context. */ +void zi_free(struct zone_import_ctx *z_import); + +/** + * Import zone from file. + * + * @note only root zone import is supported; origin must be NULL or "." + * @param z_import pointer to zone import context + * @param zone_file zone file name + * @param origin default origin + * @param rclass default class + * @param ttl default ttl + * @return 0 or an error code + */ +int zi_zone_import(struct zone_import_ctx *z_import, + const char *zone_file, const char *origin, + uint16_t rclass, uint32_t ttl); + +/** + * Check if import already in process. + * + * @param z_import pointer to zone import context. + * @return true if import already in process; false otherwise. + */ +bool zi_import_started(struct zone_import_ctx *z_import); -- 2.47.3