From e7c5c102d0eb83aeb42df57352e3c0dad7a2c80b Mon Sep 17 00:00:00 2001 From: Grigorii Demidov Date: Tue, 12 Sep 2017 13:52:10 +0200 Subject: [PATCH] daemon: reuse outbound TCP connections if possible; TLS over outbound TCP connection --- daemon/bindings.c | 103 +++ daemon/io.c | 44 +- daemon/io.h | 19 +- daemon/main.c | 1 + daemon/network.c | 2 + daemon/network.h | 1 + daemon/tls.c | 477 +++++++++++ daemon/tls.h | 47 +- daemon/worker.c | 1635 ++++++++++++++++++++++++++++--------- daemon/worker.h | 62 +- lib/layer/iterate.c | 2 +- lib/resolve.c | 1 - lib/utils.c | 103 +++ lib/utils.h | 38 +- modules/policy/policy.lua | 29 +- 15 files changed, 2145 insertions(+), 419 deletions(-) diff --git a/daemon/bindings.c b/daemon/bindings.c index 140ec7df3..7fcc9a39d 100644 --- a/daemon/bindings.c +++ b/daemon/bindings.c @@ -398,6 +398,107 @@ static int net_tls(lua_State *L) return 1; } +static int print_tls_param(const char *key, void *val, void *data) +{ + if (!val) { + return 0; + } + + struct tls_client_paramlist_entry *entry = (struct tls_client_paramlist_entry *)val; + + lua_State *L = (lua_State *)data; + + lua_newtable(L); + lua_newtable(L); + + lua_newtable(L); + for (size_t i = 0; i < entry->pins.len; ++i) { + lua_pushnumber(L, i + 1); + lua_pushstring(L, entry->pins.at[i]); + lua_settable(L, -3); + } + lua_setfield(L, -2, "pins"); + lua_newtable(L); + for (size_t i = 0; i < entry->ca_files.len; ++i) { + lua_pushnumber(L, i + 1); + lua_pushstring(L, entry->ca_files.at[i]); + lua_settable(L, -3); + } + lua_setfield(L, -2, "ca files"); + lua_setfield(L, -2, key); + + return 0; +} + +static int print_tls_client_params(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + if (!engine) { + return 0; + } + struct network *net = &engine->net; + if (!net) { + return 0; + } + if (net->tls_client_params.root == 0 ) { + return 0; + } + map_walk(&net->tls_client_params, print_tls_param, (void *)L); + return 1; +} + + +static int net_tls_client(lua_State *L) +{ + struct engine *engine = engine_luaget(L); + if (!engine) { + return 0; + } + struct network *net = &engine->net; + if (!net) { + return 0; + } + + /* Only return current credentials. */ + if (lua_gettop(L) == 0) { + return print_tls_client_params(L); + } + + const char *full_addr = NULL; + const char *ca_file = NULL; + const char *pin = NULL; + if ((lua_gettop(L) == 1) && lua_isstring(L, 1)) { + full_addr = lua_tostring(L, 1); + } else if ((lua_gettop(L) == 3) && lua_isstring(L, 1) && lua_isstring(L, 2) && lua_isstring(L, 3)) { + full_addr = lua_tostring(L, 1); + ca_file = lua_tostring(L, 2); + pin = lua_tostring(L, 3); + } else { + format_error(L, "net.tls_client either takes one parameter (\"address\") either takes three ones: (\"address\", \"ca_file\", \"pin\")"); + lua_error(L); + } + + char addr[INET6_ADDRSTRLEN]; + uint16_t port = 0; + if (kr_straddr_split(full_addr, addr, sizeof(addr), &port) != kr_ok()) { + format_error(L, "invalid IP address"); + lua_error(L); + } + + if (port == 0) { + port = 53; + } + + int r = tls_client_params_set(&net->tls_client_params, addr, port, ca_file, pin); + if (r != 0) { + lua_pushstring(L, strerror(ENOMEM)); + lua_error(L); + } + + lua_pushboolean(L, true); + return 1; +} + static int net_tls_padding(lua_State *L) { struct engine *engine = engine_luaget(L); @@ -508,6 +609,8 @@ int lib_net(lua_State *L) { "bufsize", net_bufsize }, { "tcp_pipeline", net_pipeline }, { "tls", net_tls }, + { "tls_server", net_tls }, + { "tls_client", net_tls_client }, { "tls_padding", net_tls_padding }, { "outgoing_v4", net_outgoing_v4 }, { "outgoing_v6", net_outgoing_v6 }, diff --git a/daemon/io.c b/daemon/io.c index 5ebd3ac87..112b19611 100644 --- a/daemon/io.c +++ b/daemon/io.c @@ -48,15 +48,18 @@ static void check_bufsize(uv_handle_t* handle) static void session_clear(struct session *s) { - assert(s->outgoing || s->tasks.len == 0); + assert(s->tasks.len == 0 && s->waiting.len == 0); array_clear(s->tasks); + array_clear(s->waiting); tls_free(s->tls_ctx); + tls_client_ctx_free(s->tls_client_ctx); memset(s, 0, sizeof(*s)); } void session_free(struct session *s) { if (s) { + assert(s->tasks.len == 0 && s->waiting.len == 0); session_clear(s); free(s); } @@ -89,6 +92,8 @@ static void session_release(struct worker_ctx *worker, uv_handle_t *handle) if (!s) { return; } + assert(s->waiting.len == 0 && s->tasks.len == 0); + assert(s->buffering == NULL); if (!s->outgoing && handle->type == UV_TCP) { worker_end_tcp(worker, handle); /* to free the buffering task */ } @@ -158,8 +163,10 @@ static int udp_bind_finalize(uv_handle_t *handle) { check_bufsize((uv_handle_t *)handle); /* Handle is already created, just create context. */ - handle->data = session_new(); - assert(handle->data); + struct session *session = session_new(); + assert(session); + session->handle = handle; + handle->data = session; return io_start_read((uv_handle_t *)handle); } @@ -189,20 +196,14 @@ int udp_bindfd(uv_udp_t *handle, int fd) return udp_bind_finalize((uv_handle_t *)handle); } -static void tcp_timeout(uv_handle_t *timer) -{ - uv_handle_t *handle = timer->data; - uv_close(handle, io_free); -} - static void tcp_timeout_trigger(uv_timer_t *timer) { - uv_handle_t *handle = timer->data; - struct session *session = handle->data; + struct session *session = timer->data; if (session->tasks.len > 0) { uv_timer_again(timer); } else { - uv_close((uv_handle_t *)timer, tcp_timeout); + uv_timer_stop(timer); + worker_session_close(session); } } @@ -210,12 +211,16 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) { uv_loop_t *loop = handle->loop; struct session *s = handle->data; + if (s->closing) { + return; + } struct worker_ctx *worker = loop->data; /* TCP pipelining is rather complicated and requires cooperation from the worker * so the whole message reassembly and demuxing logic is inside worker */ int ret = 0; if (s->has_tls) { - ret = tls_process(worker, handle, (const uint8_t *)buf->base, nread); + ret = s->outgoing ? tls_client_process(worker, handle, (const uint8_t *)buf->base, nread) : + tls_process(worker, handle, (const uint8_t *)buf->base, nread); } else { ret = worker_process_tcp(worker, handle, (const uint8_t *)buf->base, nread); } @@ -226,7 +231,7 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) if (!s->outgoing && !uv_is_closing((uv_handle_t *)&s->timeout)) { uv_timer_stop(&s->timeout); if (s->tasks.len == 0) { - uv_close((uv_handle_t *)&s->timeout, tcp_timeout); + worker_session_close(s); } else { /* If there are tasks running, defer until they finish. */ uv_timer_start(&s->timeout, tcp_timeout_trigger, 1, KR_CONN_RTT_MAX/2); } @@ -265,7 +270,7 @@ static void _tcp_accept(uv_stream_t *master, int status, bool tls) } uv_timer_t *timer = &session->timeout; uv_timer_init(master->loop, timer); - timer->data = client; + timer->data = session; uv_timer_start(timer, tcp_timeout_trigger, KR_CONN_RTT_MAX/2, KR_CONN_RTT_MAX/2); io_start_read((uv_handle_t *)client); } @@ -379,8 +384,12 @@ void io_create(uv_loop_t *loop, uv_handle_t *handle, int type) } struct worker_ctx *worker = loop->data; - handle->data = session_borrow(worker); - assert(handle->data); + struct session *session = session_borrow(worker); + assert(session); + session->handle = handle; + handle->data = session; + session->timeout.data = session; + uv_timer_init(worker->loop, &session->timeout); } void io_deinit(uv_handle_t *handle) @@ -388,6 +397,7 @@ void io_deinit(uv_handle_t *handle) if (!handle) { return; } + struct session *session = handle->data; uv_loop_t *loop = handle->loop; if (loop && loop->data) { struct worker_ctx *worker = loop->data; diff --git a/daemon/io.h b/daemon/io.h index dc040fe34..24c0c26e7 100644 --- a/daemon/io.h +++ b/daemon/io.h @@ -18,22 +18,35 @@ #include #include +#include #include "lib/generic/array.h" +#include "daemon/worker.h" -struct qr_task; struct tls_ctx_t; +struct tls_client_ctx_t; /* Per-session (TCP or UDP) persistent structure, * that exists between remote counterpart and a local socket. */ struct session { - bool outgoing; + bool outgoing; /**< True: to upstream; false: from a client. */ bool throttled; bool has_tls; + bool connected; + bool closing; + union inaddr peer; + uv_handle_t *handle; uv_timer_t timeout; struct qr_task *buffering; /**< Worker buffers the incomplete TCP query here. */ struct tls_ctx_t *tls_ctx; - array_t(struct qr_task *) tasks; + struct tls_client_ctx_t *tls_client_ctx; + + uint8_t msg_hdr[4]; /**< Buffer for DNS message header. */ + ssize_t msg_hdr_idx; /**< The number of bytes in msg_hdr filled so far. */ + + qr_tasklist_t tasks; + qr_tasklist_t waiting; + ssize_t bytes_to_skip; }; void session_free(struct session *s); diff --git a/daemon/main.c b/daemon/main.c index 276f0fb7f..3f22a28d1 100644 --- a/daemon/main.c +++ b/daemon/main.c @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include diff --git a/daemon/network.c b/daemon/network.c index 40b2ac011..e181cae70 100644 --- a/daemon/network.c +++ b/daemon/network.c @@ -51,6 +51,7 @@ void network_init(struct network *net, uv_loop_t *loop) if (net != NULL) { net->loop = loop; net->endpoints = map_make(); + net->tls_client_params = map_make(); } } @@ -106,6 +107,7 @@ void network_deinit(struct network *net) map_walk(&net->endpoints, free_key, 0); map_clear(&net->endpoints); tls_credentials_free(net->tls_credentials); + tls_client_params_free(&net->tls_client_params); net->tls_credentials = NULL; } } diff --git a/daemon/network.h b/daemon/network.h index 5227265f6..c562b6413 100644 --- a/daemon/network.h +++ b/daemon/network.h @@ -46,6 +46,7 @@ struct network { uv_loop_t *loop; map_t endpoints; struct tls_credentials *tls_credentials; + map_t tls_client_params; }; void network_init(struct network *net, uv_loop_t *loop); diff --git a/daemon/tls.c b/daemon/tls.c index e1ed1161d..b5a79b051 100644 --- a/daemon/tls.c +++ b/daemon/tls.c @@ -54,6 +54,19 @@ struct tls_ctx_t { struct tls_credentials *credentials; }; +struct tls_client_ctx_t { + gnutls_session_t tls_session; + tls_client_hs_state_t handshake_state; + + struct session *session; + tls_handshake_cb handshake_cb; + const uint8_t *buf; + ssize_t nread; + ssize_t consumed; + uint8_t recv_buf[4096]; + const struct tls_client_paramlist_entry *params; +}; + /** @internal Debugging facility. */ #ifdef DEBUG #define DEBUG_MSG(fmt...) fprintf(stderr, "[tls] " fmt) @@ -61,6 +74,8 @@ struct tls_ctx_t { #define DEBUG_MSG(fmt...) #endif +static int client_verify_certificate(gnutls_session_t tls_session); + static ssize_t kres_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len) { struct tls_ctx_t *t = (struct tls_ctx_t *)h; @@ -504,4 +519,466 @@ void tls_credentials_free(struct tls_credentials *tls_credentials) { free(tls_credentials); } +static int client_paramlist_entry_clear(const char *k, void *v, void *baton) +{ + struct tls_client_paramlist_entry *entry = (struct tls_client_paramlist_entry *)v; + + while (entry->ca_files.len > 0) { + if (entry->ca_files.at[0] != NULL) { + free((void *)entry->ca_files.at[0]); + } + array_del(entry->ca_files, 0); + } + + while (entry->pins.len > 0) { + if (entry->pins.at[0] != NULL) { + free((void *)entry->pins.at[0]); + } + array_del(entry->pins, 0); + } + + array_clear(entry->ca_files); + array_clear(entry->pins); + + if (entry->credentials) { + gnutls_certificate_free_credentials(entry->credentials); + } + + free(entry); + + return 0; +} + +int tls_client_params_set(map_t *tls_client_paramlist, + const char *addr, uint16_t port, + const char *ca_file, const char *pin) +{ + if (!tls_client_paramlist || !addr) { + return kr_error(EINVAL); + } + + char key[INET6_ADDRSTRLEN + 6]; + size_t keylen = sizeof(key); + if (kr_straddr_join(addr, port, key, &keylen) != kr_ok()) { + kr_log_error("[tls client] warning: '%s' is not a valid ip address, ignoring\n", addr); + return kr_ok(); + } + + bool is_first_entry = false; + struct tls_client_paramlist_entry *entry = map_get(tls_client_paramlist, key); + if (entry == NULL) { + entry = calloc(1, sizeof(struct tls_client_paramlist_entry)); + if (entry == NULL) { + return kr_error(ENOMEM); + } + is_first_entry = true; + int ret = gnutls_certificate_allocate_credentials(&entry->credentials); + if (ret != GNUTLS_E_SUCCESS) { + free(entry); + kr_log_error("[tls client] error: gnutls_certificate_allocate_credentials() fails (%s)\n", + gnutls_strerror_name(ret)); + return kr_error(ENOMEM); + } + gnutls_certificate_set_verify_function(entry->credentials, client_verify_certificate); + } + + int ret = kr_ok(); + if (ca_file && ca_file[0] != 0) { + bool already_exists = false; + for (size_t i = 0; i < entry->ca_files.len; ++i) { + if (strcmp(entry->ca_files.at[i], ca_file) == 0) { + kr_log_error("[tls client] error: ca file for address %s already was set, ignoring\n", key); + already_exists = true; + break; + } + } + if (!already_exists) { + const char *value = strdup(ca_file); + if (!value) { + ret = kr_error(ENOMEM); + } else if (array_push(entry->ca_files, value) < 0) { + free ((void *)value); + ret = kr_error(ENOMEM); + } else { + int res = gnutls_certificate_set_x509_trust_file(entry->credentials, value, + GNUTLS_X509_FMT_PEM); + if (res < 0) { + kr_log_error("[tls client], failed to import certificate file '%s' (%s)\n", + value, gnutls_strerror_name(res)); + /* value will be freed at cleanup */ + ret = kr_error(EINVAL); + } + } + } + } + + if ((ret == kr_ok()) && pin && pin[0] != 0) { + for (size_t i = 0; i < entry->pins.len; ++i) { + if (strcmp(entry->pins.at[i], pin) == 0) { + kr_log_error("[tls client] warning: pin for address %s already was set, ignoring\n", key); + return kr_ok(); + } + } + const void *value = strdup(pin); + if (!value) { + ret = kr_error(ENOMEM); + } else if (array_push(entry->pins, value) < 0) { + free ((void *)value); + ret = kr_error(ENOMEM); + } + } + + if ((ret == kr_ok()) && is_first_entry) { + bool fail = (map_set(tls_client_paramlist, key, entry) != 0); + if (fail) { + ret = kr_error(ENOMEM); + } + } + + if ((ret != kr_ok()) && is_first_entry) { + client_paramlist_entry_clear(NULL, (void *)entry, NULL); + } + + return ret; +} + +int tls_client_params_free(map_t *tls_client_paramlist) +{ + if (!tls_client_paramlist) { + return kr_error(EINVAL); + } + + map_walk(tls_client_paramlist, client_paramlist_entry_clear, NULL); + map_clear(tls_client_paramlist); + + return kr_ok(); +} + +static int client_verify_certificate(gnutls_session_t tls_session) +{ + struct tls_client_ctx_t *ctx = gnutls_session_get_ptr(tls_session); + assert(ctx->params != NULL); + + if (ctx->params->pins.len == 0 && ctx->params->ca_files.len) { + return GNUTLS_E_SUCCESS; + } + + gnutls_certificate_type_t cert_type = gnutls_certificate_type_get(tls_session); + if (cert_type != GNUTLS_CRT_X509) { + kr_log_error("[tls_client] invalid certificate type %i has been received\n", + cert_type); + return GNUTLS_E_CERTIFICATE_ERROR; + } + unsigned int cert_list_size = 0; + const gnutls_datum_t *cert_list = + gnutls_certificate_get_peers(tls_session, &cert_list_size); + if (cert_list == NULL || cert_list_size == 0) { + kr_log_error("[tls_client] empty certificate list\n"); + return GNUTLS_E_CERTIFICATE_ERROR; + } + + if (ctx->params->pins.len == 0) { + DEBUG_MSG("[tls_client] skipping certificate PIN check\n"); + goto skip_pins; + } + + for (int i = 0; i < cert_list_size; i++) { + gnutls_x509_crt_t cert; + int ret = gnutls_x509_crt_init(&cert); + if (ret != GNUTLS_E_SUCCESS) { + return ret; + } + + ret = gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER); + if (ret != GNUTLS_E_SUCCESS) { + gnutls_x509_crt_deinit(cert); + return ret; + } + + char cert_pin[PINLEN] = { 0 }; + ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin)); + + gnutls_x509_crt_deinit(cert); + + if (ret != GNUTLS_E_SUCCESS) { + return ret; + } + + DEBUG_MSG("[tls_client] received pin : %s\n", cert_pin); + for (size_t i = 0; i < ctx->params->pins.len; ++i) { + const char *pin = ctx->params->pins.at[i]; + bool match = (strcmp(cert_pin, pin) == 0); + DEBUG_MSG("[tls_client] configured pin: %s matches? %s\n", + pin, match ? "yes" : "no"); + if (match) { + return GNUTLS_E_SUCCESS; + } + } + } + +skip_pins: + + if (ctx->params->ca_files.len == 0) { + DEBUG_MSG("[tls_client] skipping certificate verification\n"); + return GNUTLS_E_SUCCESS; + } + + gnutls_typed_vdata_st data[2] = { + { .type = GNUTLS_DT_KEY_PURPOSE_OID, + .data = (void *)GNUTLS_KP_TLS_WWW_SERVER } + }; + size_t data_count = 1; + unsigned int status; + int ret = gnutls_certificate_verify_peers(ctx->tls_session, data, data_count, &status); + if (ret != GNUTLS_E_SUCCESS) { + kr_log_error("[tls_client] failed to verify peer certificate\n"); + return GNUTLS_E_CERTIFICATE_ERROR; + } + + return GNUTLS_E_SUCCESS; +} + +static ssize_t kres_gnutls_client_push(gnutls_transport_ptr_t h, const void *buf, size_t len) +{ + struct tls_client_ctx_t *t = (struct tls_client_ctx_t *)h; + const uv_buf_t ub = {(void *)buf, len}; + + DEBUG_MSG("[tls_client] push %zu <%p>\n", len, h); + if (t == NULL) { + errno = EFAULT; + return -1; + } + + int ret = uv_try_write((uv_stream_t *)t->session->handle, &ub, 1); + if (ret > 0) { + return (ssize_t) ret; + } + if (ret == UV_EAGAIN) { + errno = EAGAIN; + } else { + kr_log_error("[tls_client] uv_try_write: %s\n", uv_strerror(ret)); + errno = EIO; + } + return -1; +} + + +static ssize_t kres_gnutls_client_pull(gnutls_transport_ptr_t h, void *buf, size_t len) +{ + struct tls_client_ctx_t *t = (struct tls_client_ctx_t *)h; + assert(t != NULL); + + ssize_t avail = t->nread - t->consumed; + DEBUG_MSG("[tls] pull wanted: %zu available: %zu\n", len, avail); + if (t->nread <= t->consumed) { + errno = EAGAIN; + return -1; + } + + ssize_t transfer = MIN(avail, len); + memcpy(buf, t->buf + t->consumed, transfer); + t->consumed += transfer; + return transfer; +} + +int tls_client_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt) +{ + if (!pkt || !handle || !handle->data) { + return kr_error(EINVAL); + } + + struct session *session = handle->data; + const uint16_t pkt_size = htons(pkt->size); + struct tls_client_ctx_t *ctx = session->tls_client_ctx; + if (!ctx) { + kr_log_error("[tls_client] no tls context on push\n"); + return kr_error(ENOENT); + } + + gnutls_record_cork(ctx->tls_session); + ssize_t count = 0; + if ((count = gnutls_record_send(ctx->tls_session, &pkt_size, sizeof(pkt_size)) < 0) || + (count = gnutls_record_send(ctx->tls_session, pkt->wire, pkt->size) < 0)) { + kr_log_error("[tls_client] gnutls_record_send failed: %s (%zd)\n", gnutls_strerror_name(count), count); + return kr_error(EIO); + } + + ssize_t submitted = 0; + do { + count = gnutls_record_uncork(ctx->tls_session, 0); + if (count < 0) { + if (gnutls_error_is_fatal(count)) { + kr_log_error("[tls_client] gnutls_record_uncork failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + return kr_error(EIO); + } + } else { + submitted += count; + if (count == 0 && submitted != sizeof(pkt_size) + pkt->size) { + kr_log_error("[tls_client] gnutls_record_uncork didn't send all data: %s (%zd)\n", + gnutls_strerror_name(count), count); + return kr_error(EIO); + } + } + } while (submitted != sizeof(pkt_size) + pkt->size); + + return kr_ok(); +} + +int tls_client_process(struct worker_ctx *worker, uv_stream_t *handle, const uint8_t *buf, ssize_t nread) +{ + struct session *session = handle->data; + struct tls_client_ctx_t *ctx = session->tls_client_ctx; + if (!ctx) { + return kr_error(ENOSYS); + } + + assert(ctx->handshake_state == TLS_HS_IN_PROGRESS || + ctx->handshake_state == TLS_HS_DONE); + + ctx->buf = buf; + ctx->nread = nread; + ctx->session = session; + ctx->consumed = 0; + + if (ctx->handshake_state == TLS_HS_IN_PROGRESS) { + int err = gnutls_handshake(ctx->tls_session); + if (err == GNUTLS_E_SUCCESS) { + ctx->handshake_state = TLS_HS_DONE; + } else if (err == GNUTLS_E_AGAIN) { + return 0; + } else if (err < 0 && gnutls_error_is_fatal(err)) { + kr_log_error("[tls_client] gnutls_handshake failed: %s (%d)\n", + gnutls_strerror_name(err), err); + if (ctx->handshake_cb) { + ctx->handshake_cb(ctx->session, -1); + } + return kr_error(err); + } + if (ctx->handshake_cb) { + ctx->handshake_cb(ctx->session, 0); + } + DEBUG_MSG("[tls_client] TLS handshake with %s has completed.\n", kr_straddr(&session->peer.ip)); + } + + int submitted = 0; + while (true) { + ssize_t count = gnutls_record_recv(ctx->tls_session, ctx->recv_buf, sizeof(ctx->recv_buf)); + if (count == GNUTLS_E_AGAIN) { + break; /* No data available */ + } else if (count == GNUTLS_E_INTERRUPTED) { + continue; /* Try reading again */ + } else if (count < 0) { + kr_log_error("[tls_client] gnutls_record_recv failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + return kr_error(EIO); + } + DEBUG_MSG("[tls_client] submitting %zd data to worker\n", count); + int ret = worker_process_tcp(worker, handle, ctx->recv_buf, count); + if (ret < 0) { + return ret; + } + submitted += ret; + } + return submitted; +} + +struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_entry *entry) +{ + struct tls_client_ctx_t *ctx = calloc(1, sizeof (struct tls_client_ctx_t)); + if (!ctx) { + return NULL; + } + + int ret = gnutls_init(&ctx->tls_session, GNUTLS_CLIENT | GNUTLS_NONBLOCK); + if (ret != GNUTLS_E_SUCCESS) { + tls_client_ctx_free(ctx); + return NULL; + } + + ret = gnutls_set_default_priority(ctx->tls_session); + if (ret != GNUTLS_E_SUCCESS) { + tls_client_ctx_free(ctx); + return NULL; + } + + ret = gnutls_credentials_set(ctx->tls_session, GNUTLS_CRD_CERTIFICATE, + entry->credentials); + if (ret != GNUTLS_E_SUCCESS) { + tls_client_ctx_free(ctx); + return NULL; + } + + gnutls_transport_set_pull_function(ctx->tls_session, kres_gnutls_client_pull); + gnutls_transport_set_push_function(ctx->tls_session, kres_gnutls_client_push); + gnutls_transport_set_ptr(ctx->tls_session, ctx); + return ctx; +} + +void tls_client_ctx_free(struct tls_client_ctx_t *ctx) +{ + if (ctx == NULL) { + return; + } + + if (ctx->session != NULL) { + gnutls_deinit(ctx->tls_session); + } + + free (ctx); +} + +int tls_client_connect_start(struct tls_client_ctx_t *ctx, + struct session *session, + tls_handshake_cb handshake_cb) +{ + if (session == NULL || ctx == NULL) { + return kr_error(EINVAL); + } + + assert(session->outgoing && session->handle->type == UV_TCP); + + gnutls_session_set_ptr(ctx->tls_session, ctx); + gnutls_handshake_set_timeout(ctx->tls_session, 5000); + session->tls_client_ctx = ctx; + ctx->handshake_cb = handshake_cb; + ctx->handshake_state = TLS_HS_IN_PROGRESS; + ctx->session = session; + + int ret = gnutls_handshake(ctx->tls_session); + if (ret == GNUTLS_E_SUCCESS) { + return kr_ok(); + } else if (gnutls_error_is_fatal(ret) != 0) { + kr_log_error("[tls client] handshake failed (%s)\n", gnutls_strerror(ret)); + return kr_error(ECONNABORTED); + } + return kr_error(EAGAIN); +} + +void tls_client_close(struct tls_client_ctx_t *ctx) +{ + if (ctx == NULL || ctx->session == NULL) { + return; + } + + if (ctx->handshake_state == TLS_HS_DONE) { + gnutls_bye(ctx->tls_session, GNUTLS_SHUT_RDWR); + } +} + +tls_client_hs_state_t tls_client_get_hs_state(const struct tls_client_ctx_t *ctx) +{ + return ctx->handshake_state; +} + +int tls_client_ctx_set_params(struct tls_client_ctx_t *ctx, + const struct tls_client_paramlist_entry *entry) +{ + if (!ctx) { + return kr_error(EINVAL); + } + ctx->params = entry; + return kr_ok(); +} + #undef DEBUG_MSG diff --git a/daemon/tls.h b/daemon/tls.h index 385f8fe39..35cd08513 100644 --- a/daemon/tls.h +++ b/daemon/tls.h @@ -20,11 +20,13 @@ #include #include #include "lib/defines.h" +#include "lib/generic/array.h" +#include "lib/generic/map.h" #define MAX_TLS_PADDING KR_EDNS_PAYLOAD struct tls_ctx_t; -struct tls_credentials; +struct tls_client_ctx_t; struct tls_credentials { int count; char *tls_cert; @@ -34,6 +36,20 @@ struct tls_credentials { char *ephemeral_servicename; }; +struct tls_client_paramlist_entry { + array_t(const char *) ca_files; + array_t(const char *) pins; + gnutls_certificate_credentials_t credentials; +}; + +typedef enum tls_client_hs_state { + TLS_HS_NOT_STARTED = 0, + TLS_HS_IN_PROGRESS, + TLS_HS_DONE +} tls_client_hs_state_t; + +typedef int (*tls_handshake_cb) (struct session *session, int status); + /*! Create an empty TLS context in query context */ struct tls_ctx_t* tls_new(struct worker_ctx *worker); @@ -66,3 +82,32 @@ void tls_credentials_log_pins(struct tls_credentials *tls_credentials); /*! Generate new ephemeral TLS credentials. */ struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine); + +/*! Set TLS authentication parameters for given address. */ +int tls_client_params_set(map_t *tls_client_paramlist, + const char *addr, uint16_t port, + const char *ca_file, const char *pin); + +/*! Free TLS authentication parameters. */ +int tls_client_params_free(map_t *tls_client_paramlist); + +/*! Allocate new client TLS context */ +struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_entry *entry); + +int tls_client_process(struct worker_ctx *worker, uv_stream_t *handle, + const uint8_t *buf, ssize_t nread); + +/*! Free client TLS context */ +void tls_client_ctx_free(struct tls_client_ctx_t *ctx); + +int tls_client_connect_start(struct tls_client_ctx_t *ctx, struct session *session, + tls_handshake_cb handshake_cb); + +void tls_client_close(struct tls_client_ctx_t *ctx); + +int tls_client_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt); + +tls_client_hs_state_t tls_client_get_hs_state(const struct tls_client_ctx_t *ctx); + +int tls_client_ctx_set_params(struct tls_client_ctx_t *ctx, + const struct tls_client_paramlist_entry *entry); \ No newline at end of file diff --git a/daemon/worker.c b/daemon/worker.c index 93fdb6903..a7c96f739 100644 --- a/daemon/worker.c +++ b/daemon/worker.c @@ -27,6 +27,7 @@ #include #include #include +#include #include "lib/utils.h" #include "lib/layer.h" #include "daemon/worker.h" @@ -37,6 +38,47 @@ #define VERBOSE_MSG(qry, fmt...) QRVERBOSE(qry, "wrkr", fmt) +/** Client request state. */ +struct request_ctx +{ + struct kr_request req; + struct { + union inaddr addr; + union inaddr dst_addr; + /* uv_handle_t *handle; */ + + /** NULL if the request didn't come over network. */ + struct session *session; + } source; + struct worker_ctx *worker; + qr_tasklist_t tasks; +}; + +/** Query resolution task. */ +struct qr_task +{ + struct request_ctx *ctx; + knot_pkt_t *pktbuf; + qr_tasklist_t waiting; + uv_handle_t *pending[MAX_PENDING]; + uint16_t pending_count; + uint16_t addrlist_count; + uint16_t addrlist_turn; + uint16_t timeouts; + uint16_t iter_count; + uint16_t bytes_remaining; + struct sockaddr *addrlist; + worker_cb_t on_complete; + void *baton; + uint32_t refs; + bool finished : 1; + bool leading : 1; +}; + + +int32_t tcp_connected = 0; +int32_t tcp_waiting = 0; + /* @internal Union of various libuv objects for freelist. */ struct req { @@ -59,11 +101,46 @@ struct req #define qr_task_unref(task) \ do { if (--(task)->refs == 0) { qr_task_free(task); } } while (0) #define qr_valid_handle(task, checked) \ - (!uv_is_closing((checked)) || (task)->source.handle == (checked)) + (!uv_is_closing((checked)) || (task)->ctx->source.session->handle == (checked)) + +/** @internal get key for tcp session + * @note kr_straddr() return pointer to static string + */ +#define tcpsess_key(addr) kr_straddr(addr) /* Forward decls */ static void qr_task_free(struct qr_task *task); -static int qr_task_step(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *packet); +static int qr_task_step(struct qr_task *task, + const struct sockaddr *packet_source, + knot_pkt_t *packet); +static int qr_task_send(struct qr_task *task, uv_handle_t *handle, + struct sockaddr *addr, knot_pkt_t *pkt); +static int qr_task_finalize(struct qr_task *task, int state); +static void qr_task_complete(struct qr_task *task); +static int worker_add_tcp_connected(struct worker_ctx *worker, + const struct sockaddr *addr, + struct session *session); +static int worker_del_tcp_connected(struct worker_ctx *worker, + const struct sockaddr *addr); +static struct session* worker_find_tcp_connected(struct worker_ctx *worker, + const struct sockaddr *srv); +static int worker_add_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr *addr, + struct session *session); +static int worker_del_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr *addr); +static struct session* worker_find_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr *srv); +static int session_add_waiting(struct session *session, struct qr_task *task); +static int session_del_waiting(struct session *session, struct qr_task *task); +static int session_add_tasks(struct session *session, struct qr_task *task); +static int session_del_tasks(struct session *session, struct qr_task *task); +static void session_close(struct session *session); +static void on_session_idle_timeout(uv_timer_t *timer); +static int timer_start(struct session *session, uv_timer_cb cb, + uint64_t timeout, uint64_t repeat); +static void on_tcp_connect_timeout(uv_timer_t *timer); +static void on_tcp_watchdog_timeout(uv_timer_t *timer); /** @internal Get singleton worker. */ static inline struct worker_ctx *get_worker(void) @@ -110,18 +187,19 @@ static uv_handle_t *ioreq_spawn(struct qr_task *task, int socktype, sa_family_t return NULL; } /* Create connection for iterative query */ - uv_handle_t *handle = (uv_handle_t *)req_borrow(task->worker); + struct worker_ctx *worker = task->ctx->worker; + uv_handle_t *handle = (uv_handle_t *)req_borrow(worker); if (!handle) { return NULL; } - io_create(task->worker->loop, handle, socktype); + io_create(worker->loop, handle, socktype); /* Bind to outgoing address, according to IP v4/v6. */ union inaddr *addr; if (family == AF_INET) { - addr = (union inaddr *)&task->worker->out_addr4; + addr = (union inaddr *)&worker->out_addr4; } else { - addr = (union inaddr *)&task->worker->out_addr6; + addr = (union inaddr *)&worker->out_addr6; } int ret = 0; if (addr->ip.sa_family != AF_UNSPEC) { @@ -137,47 +215,197 @@ static uv_handle_t *ioreq_spawn(struct qr_task *task, int socktype, sa_family_t struct session *session = handle->data; if (ret == 0) { session->outgoing = true; - ret = array_push(session->tasks, task); + ret = session_add_tasks(session, task); } if (ret < 0) { io_deinit(handle); - req_release(task->worker, (struct req *)handle); + req_release(worker, (struct req *)handle); return NULL; } - qr_task_ref(task); /* Connect or issue query datagram */ task->pending[task->pending_count] = handle; task->pending_count += 1; return handle; } -static void ioreq_on_close(uv_handle_t *handle) +static void on_session_close(uv_handle_t *handle) { struct worker_ctx *worker = get_worker(); - /* Handle-type events own a session, must close it. */ struct session *session = handle->data; - struct qr_task *task = session->tasks.at[0]; + if (!session->outgoing) { + assert(session->handle->type == UV_TCP); + } + bool free_handle = false; + if (!session->outgoing && session->handle->type == UV_TCP) { + free_handle = true; + } io_deinit(handle); - qr_task_unref(task); - req_release(worker, (struct req *)handle); + if (free_handle) { + free(handle); + } else { + req_release(worker, (struct req *)handle); + } +} + +static void on_session_timer_close(uv_handle_t *timer) +{ + struct session *session = timer->data; + uv_handle_t *handle = session->handle; + if (!uv_is_closing(handle)) { + uv_close(handle, on_session_close); + } } -static void ioreq_kill(uv_handle_t *req) +static void ioreq_kill_udp(uv_handle_t *req, struct qr_task *task) { assert(req); - if (!uv_is_closing(req)) { - uv_close(req, ioreq_on_close); + struct session *session = req->data; + assert(session->outgoing); + if (session->closing) { + return; + } + uv_timer_stop(&session->timeout); + session_del_tasks(session, task); + assert(session->tasks.len == 0); + session_close(session); +} + +static void ioreq_kill_tcp(uv_handle_t *req, struct qr_task *task) +{ + assert(req); + struct session *session = req->data; + assert(session->outgoing); + if (session->closing) { + return; + } + + session_del_waiting(session, task); + session_del_tasks(session, task); + + int res = 0; + + if (session->outgoing && session->peer.ip.sa_family != AF_UNSPEC && + session->tasks.len == 0 && session->waiting.len == 0 && + session->connected && !session->closing) { + assert(session->peer.ip.sa_family == AF_INET || + session->peer.ip.sa_family == AF_INET6); + /* This is outbound TCP connection which can be reused. + * Close it after timeout */ + uv_timer_t *timer = &session->timeout; + timer->data = session; + uv_timer_stop(timer); + res = uv_timer_start(timer, on_session_idle_timeout, + KR_CONN_RTT_MAX, 0); + } + + if (res != 0) { + /* if any errors, close the session immediately */ + session_close(session); } } -static void ioreq_killall(struct qr_task *task) +static void ioreq_kill_pending(struct qr_task *task) { - for (size_t i = 0; i < task->pending_count; ++i) { - ioreq_kill(task->pending[i]); + for (uint16_t i = 0; i < task->pending_count; ++i) { + if (task->pending[i]->type == UV_UDP) { + ioreq_kill_udp(task->pending[i], task); + } else if (task->pending[i]->type == UV_TCP) { + ioreq_kill_tcp(task->pending[i], task); + } else { + assert(false); + } } task->pending_count = 0; } +static void session_close(struct session *session) +{ + assert(session->tasks.len == 0 && session->waiting.len == 0); + + if (session->closing) { + return; + } + + if (session->buffering != NULL) { + qr_task_complete(session->buffering); + session->buffering = NULL; + } + + session->closing = true; + if (session->outgoing && + session->peer.ip.sa_family != AF_UNSPEC) { + struct worker_ctx *worker = get_worker(); + struct sockaddr *peer = &session->peer.ip; + worker_del_tcp_connected(worker, peer); + session->connected = false; + } + + if (!uv_is_closing((uv_handle_t *)&session->timeout)) { + uv_timer_stop(&session->timeout); + if (session->tls_client_ctx) { + tls_client_close(session->tls_client_ctx); + } + session->timeout.data = session; + uv_close((uv_handle_t *)&session->timeout, on_session_timer_close); + } +} + +static int session_add_waiting(struct session *session, struct qr_task *task) +{ + for (int i = 0; i < session->waiting.len; ++i) { + if (session->waiting.at[i] == task) { + return i; + } + } + int ret = array_push(session->waiting, task); + if (ret >= 0) { + qr_task_ref(task); + } + return ret; +} + +static int session_del_waiting(struct session *session, struct qr_task *task) +{ + int ret = kr_error(ENOENT); + for (int i = 0; i < session->waiting.len; ++i) { + if (session->waiting.at[i] == task) { + array_del(session->waiting, i); + qr_task_unref(task); + ret = kr_ok(); + break; + } + } + return ret; +} + +static int session_add_tasks(struct session *session, struct qr_task *task) +{ + for (int i = 0; i < session->tasks.len; ++i) { + if (session->tasks.at[i] == task) { + return i; + } + } + int ret = array_push(session->tasks, task); + if (ret >= 0) { + qr_task_ref(task); + } + return ret; +} + +static int session_del_tasks(struct session *session, struct qr_task *task) +{ + int ret = kr_error(ENOENT); + for (int i = 0; i < session->tasks.len; ++i) { + if (session->tasks.at[i] == task) { + array_del(session->tasks, i); + qr_task_unref(task); + ret = kr_ok(); + break; + } + } + return ret; +} + /** @cond This memory layout is internal to mempool.c, use only for debugging. */ #if defined(__SANITIZE_ADDRESS__) struct mempool_chunk { @@ -202,9 +430,9 @@ static void mp_poison(struct mempool *mp, bool poison) #endif /** @endcond */ +/** Get a mempool. (Recycle if possible.) */ static inline struct mempool *pool_borrow(struct worker_ctx *worker) { - /* Recycle available mempool if possible */ struct mempool *mp = NULL; if (worker->pool_mp.len > 0) { mp = array_tail(worker->pool_mp); @@ -216,9 +444,9 @@ static inline struct mempool *pool_borrow(struct worker_ctx *worker) return mp; } +/** Return a mempool. (Cache them up to some count.) */ static inline void pool_release(struct worker_ctx *worker, struct mempool *mp) { - /* Return mempool to ring or free it if it's full */ if (worker->pool_mp.len < MP_FREELIST_SIZE) { mp_flush(mp); array_push(worker->pool_mp, mp); @@ -235,115 +463,120 @@ static int subreq_key(char *dst, knot_pkt_t *pkt) return kr_rrkey(dst, knot_pkt_qname(pkt), knot_pkt_qtype(pkt), knot_pkt_qclass(pkt)); } -static struct qr_task *qr_task_create(struct worker_ctx *worker, uv_handle_t *handle, const struct sockaddr *addr) +/** Create and initialize a request_ctx (on a fresh mempool). + * + * handle and addr point to the source of the request, and they are NULL + * in case the request didn't come from network. + */ +static struct request_ctx *request_create(struct worker_ctx *worker, + uv_handle_t *handle, + const struct sockaddr *addr) { - /* How much can client handle? */ - struct engine *engine = worker->engine; - size_t pktbuf_max = KR_EDNS_PAYLOAD; - if (engine->resolver.opt_rr) { - pktbuf_max = MAX(knot_edns_get_payload(engine->resolver.opt_rr), pktbuf_max); - } - - /* Recycle available mempool if possible */ knot_mm_t pool = { .ctx = pool_borrow(worker), .alloc = (knot_mm_alloc_t) mp_alloc }; - /* Create resolution task */ - struct qr_task *task = mm_alloc(&pool, sizeof(*task)); - if (!task) { - mp_delete(pool.ctx); + /* Create request context */ + struct request_ctx *ctx = mm_alloc(&pool, sizeof(*ctx)); + if (!ctx) { + pool_release(worker, pool.ctx); return NULL; } - /* Create packet buffers for answer and subrequests */ - task->req.pool = pool; - knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &task->req.pool); - if (!pktbuf) { - mp_delete(pool.ctx); - return NULL; - } - pktbuf->size = 0; - task->req.answer = NULL; - task->pktbuf = pktbuf; - array_init(task->waiting); - task->addrlist = NULL; - task->pending_count = 0; - task->bytes_remaining = 0; - task->iter_count = 0; - task->timeouts = 0; - task->refs = 1; - task->finished = false; - task->leading = false; - task->worker = worker; - task->session = NULL; - task->source.handle = handle; - task->timeout = NULL; - task->on_complete = NULL; - task->req.qsource.key = NULL; - task->req.qsource.addr = NULL; - task->req.qsource.dst_addr = NULL; - task->req.qsource.packet = NULL; - task->req.qsource.opt = NULL; - task->req.qsource.size = 0; + memset(ctx, 0, sizeof(*ctx)); + + /* TODO Relocate pool to struct request */ + ctx->worker = worker; + array_init(ctx->tasks); + ctx->source.session = handle ? handle->data : NULL; + + struct kr_request *req = &ctx->req; + req->pool = pool; + /* Remember query source addr */ - if (addr) { + if (!addr || (addr->sa_family != AF_INET && addr->sa_family != AF_INET6)) { + ctx->source.addr.ip.sa_family = AF_UNSPEC; + } else { size_t addr_len = sizeof(struct sockaddr_in); if (addr->sa_family == AF_INET6) addr_len = sizeof(struct sockaddr_in6); - memcpy(&task->source.addr, addr, addr_len); - task->req.qsource.addr = (const struct sockaddr *)&task->source.addr; - } else { - task->source.addr.ip4.sin_family = AF_UNSPEC; + memcpy(&ctx->source.addr.ip, addr, addr_len); + ctx->req.qsource.addr = &ctx->source.addr.ip; } + + worker->stats.rconcurrent += 1; + + if (!handle) { + return ctx; + } + /* Remember the destination address. */ - if (handle) { - int addr_len = sizeof(task->source.dst_addr); - struct sockaddr *dst_addr = (struct sockaddr *)&task->source.dst_addr; - task->source.dst_addr.ip4.sin_family = AF_UNSPEC; - if (handle->type == UV_UDP) { - if (uv_udp_getsockname((uv_udp_t *)handle, dst_addr, &addr_len) == 0) { - task->req.qsource.dst_addr = dst_addr; - } - task->req.qsource.tcp = false; - } else if (handle->type == UV_TCP) { - if (uv_tcp_getsockname((uv_tcp_t *)handle, dst_addr, &addr_len) == 0) { - task->req.qsource.dst_addr = dst_addr; - } - task->req.qsource.tcp = true; + int addr_len = sizeof(ctx->source.dst_addr); + struct sockaddr *dst_addr = &ctx->source.dst_addr.ip; + ctx->source.dst_addr.ip.sa_family = AF_UNSPEC; + if (handle->type == UV_UDP) { + if (uv_udp_getsockname((uv_udp_t *)handle, dst_addr, &addr_len) == 0) { + req->qsource.dst_addr = dst_addr; + } + req->qsource.tcp = false; + } else if (handle->type == UV_TCP) { + if (uv_tcp_getsockname((uv_tcp_t *)handle, dst_addr, &addr_len) == 0) { + req->qsource.dst_addr = dst_addr; } + req->qsource.tcp = true; } - worker->stats.concurrent += 1; - return task; + + return ctx; } -/* This is called when the task refcount is zero, free memory. */ -static void qr_task_free(struct qr_task *task) +/** More initialization, related to the particular incoming query/packet. */ +static int request_start(struct request_ctx *ctx, knot_pkt_t *query) { - struct session *session = task->session; - if (session) { - /* Walk the session task list and remove itself. */ - for (size_t i = 0; i < session->tasks.len; ++i) { - if (session->tasks.at[i] == task) { - array_del(session->tasks, i); - break; - } - } - /* Start reading again if the session is throttled and - * the number of outgoing requests is below watermark. */ - uv_handle_t *handle = task->source.handle; - if (handle && session->tasks.len < task->worker->tcp_pipeline_max/2) { - if (!uv_is_closing(handle) && session->throttled) { - io_start_read(handle); - session->throttled = false; - } - } + assert(query && ctx); + size_t answer_max = KNOT_WIRE_MIN_PKTSIZE; + struct kr_request *req = &ctx->req; + + /* source.session can be empty if request was generated by kresd itself */ + if (!ctx->source.session || + ctx->source.session->handle->type == UV_TCP) { + answer_max = KNOT_WIRE_MAX_PKTSIZE; + } else if (knot_pkt_has_edns(query)) { /* EDNS */ + answer_max = MAX(knot_edns_get_payload(query->opt_rr), + KNOT_WIRE_MIN_PKTSIZE); } - /* Update stats */ - struct worker_ctx *worker = task->worker; - worker->stats.concurrent -= 1; + req->qsource.size = query->size; + + req->answer = knot_pkt_new(NULL, answer_max, &req->pool); + if (!req->answer) { + return kr_error(ENOMEM); + } + + /* Remember query source TSIG key */ + if (query->tsig_rr) { + req->qsource.key = knot_rrset_copy(query->tsig_rr, &req->pool); + } + + /* Remember query source EDNS data */ + if (query->opt_rr) { + req->qsource.opt = knot_rrset_copy(query->opt_rr, &req->pool); + } + /* Start resolution */ + struct worker_ctx *worker = ctx->worker; + struct engine *engine = worker->engine; + kr_resolve_begin(req, &engine->resolver, req->answer); + worker->stats.queries += 1; + /* Throttle outbound queries only when high pressure */ + if (worker->stats.concurrent < QUERY_RATE_THRESHOLD) { + req->options.NO_THROTTLE = true; + } + return kr_ok(); +} + +static void request_free(struct request_ctx *ctx) +{ + struct worker_ctx *worker = ctx->worker; /* Return mempool to ring or free it if it's full */ - pool_release(worker, task->req.pool.ctx); + pool_release(worker, ctx->req.pool.ctx); /* @note The 'task' is invalidated from now on. */ /* Decommit memory every once in a while */ static int mp_delete_count = 0; @@ -354,105 +587,202 @@ static void qr_task_free(struct qr_task *task) #endif mp_delete_count = 0; } + worker->stats.rconcurrent -= 1; } -static int qr_task_start(struct qr_task *task, knot_pkt_t *query) +static int request_add_tasks(struct request_ctx *ctx, struct qr_task *task) { - assert(task && query); - size_t answer_max = KNOT_WIRE_MIN_PKTSIZE; - if (!task->source.handle || task->source.handle->type == UV_TCP) { - answer_max = KNOT_WIRE_MAX_PKTSIZE; - } else if (knot_pkt_has_edns(query)) { /* EDNS */ - answer_max = MAX(knot_edns_get_payload(query->opt_rr), KNOT_WIRE_MIN_PKTSIZE); + for (int i = 0; i < ctx->tasks.len; ++i) { + if (ctx->tasks.at[i] == task) { + return i; + } + } + int ret = array_push(ctx->tasks, task); + if (ret >= 0) { + qr_task_ref(task); } + return ret; +} - /* Remember query packet size */ - task->req.qsource.size = query->size; +static int request_del_tasks(struct request_ctx *ctx, struct qr_task *task) +{ + int ret = kr_error(ENOENT); + for (int i = 0; i < ctx->tasks.len; ++i) { + if (ctx->tasks.at[i] == task) { + array_del(ctx->tasks, i); + qr_task_unref(task); + ret = kr_ok(); + break; + } + } + return ret; +} - knot_pkt_t *answer = knot_pkt_new(NULL, answer_max, &task->req.pool); - if (!answer) { - return kr_error(ENOMEM); + +static struct qr_task *qr_task_create(struct request_ctx *ctx) +{ + /* How much can client handle? */ + struct engine *engine = ctx->worker->engine; + size_t pktbuf_max = KR_EDNS_PAYLOAD; + if (engine->resolver.opt_rr) { + pktbuf_max = MAX(knot_edns_get_payload(engine->resolver.opt_rr), + pktbuf_max); } - task->req.answer = answer; - /* Remember query source TSIG key */ - if (query->tsig_rr) { - task->req.qsource.key = knot_rrset_copy(query->tsig_rr, &task->req.pool); + /* Create resolution task */ + struct qr_task *task = mm_alloc(&ctx->req.pool, sizeof(*task)); + if (!task) { + return NULL; } + memset(task, 0, sizeof(*task)); /* avoid accidentally unitialized fields */ - /* Remember query source EDNS data */ - if (query->opt_rr) { - task->req.qsource.opt = knot_rrset_copy(query->opt_rr, &task->req.pool); + /* Create packet buffers for answer and subrequests */ + knot_pkt_t *pktbuf = knot_pkt_new(NULL, pktbuf_max, &ctx->req.pool); + if (!pktbuf) { + mm_free(&ctx->req.pool, task); + return NULL; } + pktbuf->size = 0; - /* Start resolution */ - struct worker_ctx *worker = task->worker; - struct engine *engine = worker->engine; - kr_resolve_begin(&task->req, &engine->resolver, answer); - worker->stats.queries += 1; - /* Throttle outbound queries only when high pressure */ - if (worker->stats.concurrent < QUERY_RATE_THRESHOLD) { - task->req.options.NO_THROTTLE = true; + task->ctx = ctx; + task->pktbuf = pktbuf; + array_init(task->waiting); + task->refs = 0; + int ret = request_add_tasks(ctx, task); + if (ret < 0) { + mm_free(&ctx->req.pool, task); + mm_free(&ctx->req.pool, pktbuf); + return NULL; } - return 0; + ctx->worker->stats.concurrent += 1; + return task; } -/*@ Register qr_task within session. */ +/* This is called when the task refcount is zero, free memory. */ +static void qr_task_free(struct qr_task *task) +{ + struct request_ctx *ctx = task->ctx; + + assert(ctx); + + /* Process outbound session. */ + struct session *source_session = ctx->source.session; + struct worker_ctx *worker = ctx->worker; + + /* Process source session. */ + if (source_session) { + /* Walk the session task list and remove itself. */ + session_del_tasks(source_session, task); + /* Start reading again if the session is throttled and + * the number of outgoing requests is below watermark. */ + uv_handle_t *handle = source_session->handle; + if (handle && source_session->tasks.len < worker->tcp_pipeline_max/2) { + if (!uv_is_closing(handle) && source_session->throttled) { + io_start_read(handle); + source_session->throttled = false; + } + } + } + + if (ctx->tasks.len == 0) { + array_clear(ctx->tasks); + request_free(ctx); + } + + /* Update stats */ + worker->stats.concurrent -= 1; +} + +/*@ Register new qr_task within session. */ static int qr_task_register(struct qr_task *task, struct session *session) { + assert(session->outgoing == false); + int ret = array_reserve(session->tasks, session->tasks.len + 1); if (ret != 0) { return kr_error(ENOMEM); } - array_push(session->tasks, task); - task->session = session; + + session_add_tasks(session, task); + + struct request_ctx *ctx = task->ctx; + assert(ctx && (ctx->source.session == NULL || ctx->source.session == session)); + ctx->source.session = session; /* Soft-limit on parallel queries, there is no "slow down" RCODE * that we could use to signalize to client, but we can stop reading, * an in effect shrink TCP window size. To get more precise throttling, * we would need to copy remainder of the unread buffer and reassemble * when resuming reading. This is NYI. */ - if (session->tasks.len >= task->worker->tcp_pipeline_max) { - uv_handle_t *handle = task->source.handle; + if (session->tasks.len >= task->ctx->worker->tcp_pipeline_max) { + uv_handle_t *handle = session->handle; if (handle && !session->throttled && !uv_is_closing(handle)) { io_stop_read(handle); session->throttled = true; } } + return 0; } static void qr_task_complete(struct qr_task *task) { - struct worker_ctx *worker = task->worker; + struct request_ctx *ctx = task->ctx; + struct worker_ctx *worker = ctx->worker; /* Kill pending I/O requests */ - ioreq_killall(task); + ioreq_kill_pending(task); assert(task->waiting.len == 0); assert(task->leading == false); /* Run the completion callback. */ if (task->on_complete) { - task->on_complete(worker, &task->req, task->baton); + task->on_complete(worker, &ctx->req, task->baton); } /* Release primary reference to task. */ - qr_task_unref(task); + request_del_tasks(ctx, task); } /* This is called when we send subrequest / answer */ static int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status) { - if (!task->finished) { - if (status == 0 && handle) { - /* For TCP we can be sure there will be no retransmit, so we flush - * the packet buffer so it can be reused again for reassembly. */ - if (handle->type == UV_TCP) { - knot_pkt_t *pktbuf = task->pktbuf; - knot_pkt_clear(pktbuf); - pktbuf->size = 0; + if (task->finished) { + assert(task->leading == false); + qr_task_complete(task); + if (!handle || handle->type != UV_TCP) { + return status; + } + struct session* session = handle->data; + if (!session->outgoing || + session->waiting.len == 0) { + return status; + } + } + + + if (status == 0 && handle) { + struct session* session = handle->data; + if (handle->type == UV_TCP && session->outgoing && + session->waiting.len > 0) { + session_del_waiting(session, task); + if (session->waiting.len > 0) { + struct qr_task *t = session->waiting.at[0]; + int ret = qr_task_send(t, (uv_handle_t *)handle, + &session->peer.ip, t->pktbuf); + if (ret != kr_ok()) { + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + array_del(session->waiting, 0); + qr_task_finalize(task, KR_STATE_FAIL); + } + while (session->tasks.len > 0) { + struct qr_task *task = session->tasks.at[0]; + array_del(session->tasks, 0); + qr_task_finalize(task, KR_STATE_FAIL); + } + session_close(session); + return status; + } } - io_start_read(handle); /* Start reading new query */ } - } else { - assert(task->timeout == NULL); - qr_task_complete(task); + io_start_read(handle); /* Start reading new query */ } return status; } @@ -488,12 +818,26 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle, struct sockad /* Synchronous push to TLS context, bypassing event loop. */ struct session *session = handle->data; if (session->has_tls) { - int ret = tls_push(task, handle, pkt); + struct kr_request *req = &task->ctx->req; + int ret = kr_ok(); + if (!session->outgoing) { + ret = tls_push(task, handle, pkt); + } else { + ret = kr_resolve_checkout(req, NULL, addr, + SOCK_STREAM, pkt); + if (ret != kr_ok()) { + return ret; + } + ret = tls_client_push(task, handle, pkt); + } return qr_task_on_send(task, handle, ret); } int ret = 0; - struct req *send_req = req_borrow(task->worker); + struct request_ctx *ctx = task->ctx; + struct worker_ctx *worker = ctx->worker; + struct kr_request *req = &ctx->req; + struct req *send_req = req_borrow(worker); if (!send_req) { return qr_task_on_send(task, handle, kr_error(ENOMEM)); } @@ -510,11 +854,11 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle, struct sockad * @note -- A solution might be opening a separate socket and * trying to obtain the IP address from it. */ - ret = kr_resolve_checkout(&task->req, NULL, addr, + ret = kr_resolve_checkout(req, NULL, addr, handle->type == UV_UDP ? SOCK_DGRAM : SOCK_STREAM, pkt); if (ret != kr_ok()) { - req_release(task->worker, send_req); + req_release(worker, send_req); return ret; } } @@ -534,59 +878,228 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle, struct sockad } if (ret == 0) { qr_task_ref(task); /* Pending ioreq on current task */ + if (worker->too_many_open && + worker->stats.rconcurrent < + worker->rconcurrent_highwatermark - (worker->rconcurrent_highwatermark / 4)) { + worker->too_many_open = false; + } } else { - req_release(task->worker, send_req); + req_release(worker, send_req); + if (ret == UV_EMFILE) { + worker->too_many_open = true; + worker->rconcurrent_highwatermark = worker->stats.rconcurrent; + } } /* Update statistics */ - if (handle != task->source.handle && addr) { + if (ctx->source.session && + handle != ctx->source.session->handle && + addr) { if (handle->type == UV_UDP) - task->worker->stats.udp += 1; + worker->stats.udp += 1; else - task->worker->stats.tcp += 1; + worker->stats.tcp += 1; if (addr->sa_family == AF_INET6) - task->worker->stats.ipv6 += 1; + worker->stats.ipv6 += 1; else - task->worker->stats.ipv4 += 1; + worker->stats.ipv4 += 1; } return ret; } +static int session_next_waiting_send(struct session *session) +{ + union inaddr *peer = &session->peer; + int ret = kr_ok(); + if (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + ret = qr_task_send(task, session->handle, &peer->ip, task->pktbuf); + if (ret == kr_ok()) { + session->timeout.data = session; + timer_start(session, on_tcp_watchdog_timeout, MAX_TCP_INACTIVITY, 0); + } + } + return ret; +} + +static int session_tls_hs_cb(struct session *session, int status) +{ + VERBOSE_MSG(NULL, "=> server: '%s' TLS handshake has %s\n", + kr_straddr(&session->peer.ip), status ? "failed" : "completed"); + if (status == 0) { + int ret = session_next_waiting_send(session); + if (ret == kr_ok()) { + struct worker_ctx *worker = get_worker(); + union inaddr *peer = &session->peer; + int ret = worker_add_tcp_connected(worker, &peer->ip, session); + assert(ret == 0); + } + } + return kr_ok(); +} + static void on_connect(uv_connect_t *req, int status) { struct worker_ctx *worker = get_worker(); - struct qr_task *task = req->data; uv_stream_t *handle = req->handle; - if (qr_valid_handle(task, (uv_handle_t *)req->handle)) { - if (status == 0) { - struct sockaddr_storage addr; - int addr_len = sizeof(addr); - uv_tcp_getpeername((uv_tcp_t *)handle, (struct sockaddr *)&addr, &addr_len); - qr_task_send(task, (uv_handle_t *)handle, (struct sockaddr *)&addr, task->pktbuf); - } else { - qr_task_step(task, task->addrlist, NULL); + struct session *session = handle->data; + + union inaddr *peer = &session->peer; + uv_timer_stop((uv_timer_t *)&session->timeout); + + if (status == UV_ECANCELED) { + worker_del_tcp_waiting(worker, &peer->ip); + assert(session->closing && session->waiting.len == 0 && session->tasks.len == 0); + req_release(worker, (struct req *)req); + return; + } + + if (session->closing) { + worker_del_tcp_waiting(worker, &peer->ip); + assert(session->waiting.len == 0 && session->tasks.len == 0); + req_release(worker, (struct req *)req); + return; + } + + if (worker_del_tcp_waiting(worker, &peer->ip) != 0) { + /* session isn't in list of waiting queries, * + * something gone wrong */ + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + session_del_tasks(session, task); + array_del(session->waiting, 0); + qr_task_finalize(task, KR_STATE_FAIL); + qr_task_unref(task); } + assert(session->tasks.len == 0); + req_release(worker, (struct req *)req); + session_close(session); + return; } - qr_task_unref(task); + + if (status != 0) { + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + session_del_tasks(session, task); + array_del(session->waiting, 0); + qr_task_finalize(task, KR_STATE_FAIL); + qr_task_unref(task); + } + assert(session->tasks.len == 0); + req_release(worker, (struct req *)req); + session_close(session); + return; + } + + session->connected = true; + session->handle = (uv_handle_t *)handle; + + int ret = kr_ok(); + if (session->has_tls) { + ret = tls_client_connect_start(session->tls_client_ctx, + session, session_tls_hs_cb); + if (ret == kr_error(EAGAIN)) { + req_release(worker, (struct req *)req); + io_start_read(session->handle); + return; + } + } + + if (ret == kr_ok()) { + ret = session_next_waiting_send(session); + if (ret == kr_ok()) { + worker_add_tcp_connected(worker, &session->peer.ip, session); + req_release(worker, (struct req *)req); + return; + } + } + + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + session_del_tasks(session, task); + array_del(session->waiting, 0); + qr_task_finalize(task, KR_STATE_FAIL); + qr_task_unref(task); + } + + assert(session->tasks.len == 0); + req_release(worker, (struct req *)req); + session_close(session); } -static void on_timer_close(uv_handle_t *handle) +static void on_tcp_connect_timeout(uv_timer_t *timer) { - struct qr_task *task = handle->data; - req_release(task->worker, (struct req *)handle); - qr_task_unref(task); + struct session *session = timer->data; + + uv_timer_stop(timer); + struct worker_ctx *worker = get_worker(); + + assert (session->waiting.len == session->tasks.len); + + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + struct request_ctx *ctx = task->ctx; + task->timeouts += 1; + worker->stats.timeout += 1; + session_del_tasks(session, task); + array_del(session->waiting, 0); + qr_task_unref(task); + assert(task->refs == 1); + qr_task_finalize(task, KR_STATE_FAIL); + } + + assert (session->tasks.len == 0); + session_close(session); +} + +static void on_tcp_watchdog_timeout(uv_timer_t *timer) +{ + struct session *session = timer->data; + + assert(session->outgoing); + uv_timer_stop(timer); + struct worker_ctx *worker = get_worker(); + + worker_del_tcp_connected(worker, &session->peer.ip); + + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + task->timeouts += 1; + worker->stats.timeout += 1; + array_del(session->waiting, 0); + qr_task_unref(task); + session_del_tasks(session, task); + qr_task_finalize(task, KR_STATE_FAIL); + } + + while (session->tasks.len > 0) { + struct qr_task *task = session->tasks.at[0]; + task->timeouts += 1; + worker->stats.timeout += 1; + assert(task->refs > 1); + array_del(session->tasks, 0); + qr_task_unref(task); + qr_task_finalize(task, KR_STATE_FAIL); + } + + session_close(session); } /* This is called when I/O timeouts */ -static void on_timeout(uv_timer_t *req) +static void on_udp_timeout(uv_timer_t *timer) { - struct qr_task *task = req->data; + struct session *session = timer->data; + uv_timer_stop(timer); + + assert(session->tasks.len == 1); + assert(session->waiting.len == 0); /* Penalize all tried nameservers with a timeout. */ - struct worker_ctx *worker = task->worker; + struct qr_task *task = session->tasks.at[0]; + struct worker_ctx *worker = task->ctx->worker; if (task->leading && task->pending_count > 0) { - struct kr_query *qry = array_tail(task->req.rplan.pending); + struct kr_query *qry = array_tail(task->ctx->req.rplan.pending); struct sockaddr_in6 *addrlist = (struct sockaddr_in6 *)task->addrlist; for (uint16_t i = 0; i < MIN(task->pending_count, task->addrlist_count); ++i) { struct sockaddr *choice = (struct sockaddr *)(&addrlist[i]); @@ -599,94 +1112,91 @@ static void on_timeout(uv_timer_t *req) worker->engine->resolver.cache_rtt, KR_NS_UPDATE); } } - /* Release timer handle */ - task->timeout = NULL; - uv_close((uv_handle_t *)req, on_timer_close); /* Return borrowed task here */ - /* Interrupt current pending request. */ task->timeouts += 1; worker->stats.timeout += 1; qr_task_step(task, NULL, NULL); } -static bool retransmit(struct qr_task *task) +static void on_session_idle_timeout(uv_timer_t *timer) +{ + struct session *s = timer->data; + assert(s && s->outgoing); + uv_timer_stop(timer); + if (s->closing) { + return; + } + /* session was not in use during timer timeout + * remove it from connection list and close + */ + assert(s->tasks.len == 0 && s->waiting.len == 0); + session_close(s); +} + +static uv_handle_t *retransmit(struct qr_task *task) { + uv_handle_t *ret = NULL; if (task && task->addrlist && task->addrlist_count > 0) { struct sockaddr_in6 *choice = &((struct sockaddr_in6 *)task->addrlist)[task->addrlist_turn]; - uv_handle_t *subreq = ioreq_spawn(task, SOCK_DGRAM, choice->sin6_family); - if (subreq) { /* Create connection for iterative query */ - if (qr_task_send(task, subreq, (struct sockaddr *)choice, task->pktbuf) == 0) { - task->addrlist_turn = (task->addrlist_turn + 1) % task->addrlist_count; /* Round robin */ - return true; - } + ret = ioreq_spawn(task, SOCK_DGRAM, choice->sin6_family); + if (ret && + qr_task_send(task, ret, (struct sockaddr *)choice, + task->pktbuf) == 0) { + task->addrlist_turn = (task->addrlist_turn + 1) % + task->addrlist_count; /* Round robin */ } } - return false; + return ret; } static void on_retransmit(uv_timer_t *req) { - struct qr_task *task = req->data; - assert(task->finished == false); - assert(task->timeout != NULL); + struct session *session = req->data; + assert(session->tasks.len == 1); uv_timer_stop(req); - if (!retransmit(req->data)) { + struct qr_task *task = session->tasks.at[0]; + if (retransmit(task) == NULL) { /* Not possible to spawn request, start timeout timer with remaining deadline. */ uint64_t timeout = KR_CONN_RTT_MAX - task->pending_count * KR_CONN_RETRY; - uv_timer_start(req, on_timeout, timeout, 0); + uv_timer_start(req, on_udp_timeout, timeout, 0); } else { uv_timer_start(req, on_retransmit, KR_CONN_RETRY, 0); } } -static int timer_start(struct qr_task *task, uv_timer_cb cb, uint64_t timeout, uint64_t repeat) +static int timer_start(struct session *session, uv_timer_cb cb, + uint64_t timeout, uint64_t repeat) { - assert(task->timeout == NULL); - struct worker_ctx *worker = task->worker; - uv_timer_t *timer = (uv_timer_t *)req_borrow(worker); - if (!timer) { - return kr_error(ENOMEM); - } - uv_timer_init(worker->loop, timer); + uv_timer_t *timer = (uv_timer_t *)&session->timeout; + assert(timer->data == session); int ret = uv_timer_start(timer, cb, timeout, repeat); if (ret != 0) { uv_timer_stop(timer); - req_release(worker, (struct req *)timer); return kr_error(ENOMEM); } - timer->data = task; - qr_task_ref(task); - task->timeout = timer; return 0; } static void subreq_finalize(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *pkt) { /* Close pending timer */ - if (task->timeout) { - /* Timer was running so it holds reference to task, make sure the timer event - * never fires and release the reference on timer close instead. */ - uv_timer_stop(task->timeout); - uv_close((uv_handle_t *)task->timeout, on_timer_close); - task->timeout = NULL; - } - ioreq_killall(task); + ioreq_kill_pending(task); /* Clear from outgoing table. */ if (!task->leading) return; char key[KR_RRKEY_LEN]; int ret = subreq_key(key, task->pktbuf); if (ret > 0) { - assert(map_get(&task->worker->outgoing, key) == task); - map_del(&task->worker->outgoing, key); + assert(map_get(&task->ctx->worker->outgoing, key) == task); + map_del(&task->ctx->worker->outgoing, key); } /* Notify waiting tasks. */ - struct kr_query *leader_qry = array_tail(task->req.rplan.pending); + struct kr_query *leader_qry = array_tail(task->ctx->req.rplan.pending); for (size_t i = task->waiting.len; i > 0; i--) { struct qr_task *follower = task->waiting.at[i - 1]; /* Reuse MSGID and 0x20 secret */ - if (follower->req.rplan.pending.len > 0) { - struct kr_query *qry = array_tail(follower->req.rplan.pending); + if (follower->ctx->req.rplan.pending.len > 0) { + struct kr_query *qry = array_tail(follower->ctx->req.rplan.pending); qry->id = leader_qry->id; qry->secret = leader_qry->secret; leader_qry->secret = 0; /* Next will be already decoded */ @@ -703,8 +1213,8 @@ static void subreq_lead(struct qr_task *task) assert(task); char key[KR_RRKEY_LEN]; if (subreq_key(key, task->pktbuf) > 0) { - assert(map_contains(&task->worker->outgoing, key) == false); - map_set(&task->worker->outgoing, key, task); + assert(map_contains(&task->ctx->worker->outgoing, key) == false); + map_set(&task->ctx->worker->outgoing, key, task); task->leading = true; } } @@ -714,10 +1224,11 @@ static bool subreq_enqueue(struct qr_task *task) assert(task); char key[KR_RRKEY_LEN]; if (subreq_key(key, task->pktbuf) > 0) { - struct qr_task *leader = map_get(&task->worker->outgoing, key); + struct qr_task *leader = map_get(&task->ctx->worker->outgoing, key); if (leader) { /* Enqueue itself to leader for this subrequest. */ - int ret = array_reserve_mm(leader->waiting, leader->waiting.len + 1, kr_memreserve, &leader->req.pool); + int ret = array_reserve_mm(leader->waiting, leader->waiting.len + 1, + kr_memreserve, &leader->ctx->req.pool); if (ret == 0) { array_push(leader->waiting, task); qr_task_ref(task); @@ -731,31 +1242,59 @@ static bool subreq_enqueue(struct qr_task *task) static int qr_task_finalize(struct qr_task *task, int state) { assert(task && task->leading == false); - kr_resolve_finish(&task->req, state); + struct request_ctx *ctx = task->ctx; + kr_resolve_finish(&ctx->req, state); task->finished = true; /* Send back answer */ - (void) qr_task_send(task, task->source.handle, (struct sockaddr *)&task->source.addr, task->req.answer); + if (ctx->source.session != NULL) { + (void) qr_task_send(task, ctx->source.session->handle, + (struct sockaddr *)&ctx->source.addr, + ctx->req.answer); + } else { + (void) qr_task_on_send(task, NULL, kr_error(EIO)); + } return state == KR_STATE_DONE ? 0 : kr_error(EIO); } -static int qr_task_step(struct qr_task *task, const struct sockaddr *packet_source, knot_pkt_t *packet) +static int qr_task_step(struct qr_task *task, + const struct sockaddr *packet_source, knot_pkt_t *packet) { /* No more steps after we're finished. */ if (!task || task->finished) { return kr_error(ESTALE); } + + /* Close pending I/O requests */ subreq_finalize(task, packet_source, packet); /* Consume input and produce next query */ + struct request_ctx *ctx = task->ctx; + struct kr_request *req = &ctx->req; + struct worker_ctx *worker = ctx->worker; int sock_type = -1; task->addrlist = NULL; task->addrlist_count = 0; task->addrlist_turn = 0; - task->req.has_tls = (task->session && task->session->has_tls); - int state = kr_resolve_consume(&task->req, packet_source, packet); + req->has_tls = (ctx->source.session && ctx->source.session->has_tls); + + if (worker->too_many_open) { + struct kr_rplan *rplan = &req->rplan; + if (worker->stats.rconcurrent < + worker->rconcurrent_highwatermark - (worker->rconcurrent_highwatermark / 4)) { + worker->too_many_open = false; + } else if (packet && kr_rplan_empty(rplan)) { + /* new query; TODO - make this detection more obvious */ + kr_resolve_consume(req, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + } + + int state = kr_resolve_consume(req, packet_source, packet); while (state == KR_STATE_PRODUCE) { - state = kr_resolve_produce(&task->req, &task->addrlist, &sock_type, task->pktbuf); - if (unlikely(++task->iter_count > KR_ITER_LIMIT || task->timeouts >= KR_TIMEOUT_LIMIT)) { + state = kr_resolve_produce(req, &task->addrlist, + &sock_type, task->pktbuf); + if (unlikely(++task->iter_count > KR_ITER_LIMIT || + task->timeouts >= KR_TIMEOUT_LIMIT)) { return qr_task_finalize(task, KR_STATE_FAIL); } } @@ -781,55 +1320,189 @@ static int qr_task_step(struct qr_task *task, const struct sockaddr *packet_sour if (subreq_enqueue(task)) { return kr_ok(); /* Will be notified when outgoing query finishes. */ } - /* Check current query NSLIST */ - struct kr_query *qry = array_tail(task->req.rplan.pending); /* Start transmitting */ - if (retransmit(task)) { - assert(qry != NULL); - /* Retransmit at default interval, or more frequently if the mean - * RTT of the server is better. If the server is glued, use default rate. */ - size_t timeout = qry->ns.score; - if (timeout > KR_NS_GLUED) { - /* We don't have information about variance in RTT, expect +10ms */ - timeout = MIN(qry->ns.score + 10, KR_CONN_RETRY); - } else { - timeout = KR_CONN_RETRY; - } - ret = timer_start(task, on_retransmit, timeout, 0); - } else { + uv_handle_t *handle = retransmit(task); + if (handle == NULL) { return qr_task_step(task, NULL, NULL); } + /* Check current query NSLIST */ + struct kr_query *qry = array_tail(req->rplan.pending); + assert(qry != NULL); + /* Retransmit at default interval, or more frequently if the mean + * RTT of the server is better. If the server is glued, use default rate. */ + size_t timeout = qry->ns.score; + if (timeout > KR_NS_GLUED) { + /* We don't have information about variance in RTT, expect +10ms */ + timeout = MIN(qry->ns.score + 10, KR_CONN_RETRY); + } else { + timeout = KR_CONN_RETRY; + } /* Announce and start subrequest. * @note Only UDP can lead I/O as it doesn't touch 'task->pktbuf' for reassembly. */ subreq_lead(task); - } else { - uv_connect_t *conn = (uv_connect_t *)req_borrow(task->worker); - if (!conn) { - return qr_task_step(task, NULL, NULL); + struct session *session = handle->data; + ret = timer_start(session, on_retransmit, timeout, 0); + /* Start next step with timeout, fatal if can't start a timer. */ + if (ret != 0) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); } + } else { + assert (sock_type == SOCK_STREAM); const struct sockaddr *addr = packet_source ? packet_source : task->addrlist; - uv_handle_t *client = ioreq_spawn(task, sock_type, addr->sa_family); - if (!client) { - req_release(task->worker, (struct req *)conn); - return qr_task_step(task, NULL, NULL); - } - conn->data = task; - if (uv_tcp_connect(conn, (uv_tcp_t *)client, addr , on_connect) != 0) { - req_release(task->worker, (struct req *)conn); - return qr_task_step(task, NULL, NULL); + if (addr->sa_family == AF_UNSPEC) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); } - qr_task_ref(task); /* Connect request borrows task */ - ret = timer_start(task, on_timeout, KR_CONN_RTT_MAX, 0); - } + struct session* session = NULL; + if ((session = worker_find_tcp_waiting(ctx->worker, addr)) != NULL) { + if (session->closing) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + /* There are waiting tasks. + * It means that connection establishing or data sending + * is coming right now. */ + /* Task will be notified in on_connect() or qr_task_on_send(). */ + ret = session_add_waiting(session, task); + if (ret < 0) { + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + ret = session_add_tasks(session, task); + if (ret < 0) { + session_del_waiting(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + } else if ((session = worker_find_tcp_connected(ctx->worker, addr)) != NULL) { + /* Connection has been already established */ + assert(session->outgoing); + if (session->closing) { + session_del_tasks(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } - /* Start next step with timeout, fatal if can't start a timer. */ - if (ret != 0) { - subreq_finalize(task, packet_source, packet); - return qr_task_finalize(task, KR_STATE_FAIL); + if (session->tasks.len >= worker->tcp_pipeline_max) { + session_del_tasks(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + + /* will be removed in qr_task_on_send() */ + ret = session_add_waiting(session, task); + if (ret < 0) { + session_del_tasks(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + ret = session_add_tasks(session, task); + if (ret < 0) { + session_del_waiting(session, task); + session_del_tasks(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + if (session->waiting.len == 1) { + ret = qr_task_send(task, session->handle, + &session->peer.ip, task->pktbuf); + if (ret < 0) { + session_del_waiting(session, task); + session_del_tasks(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + ret = timer_start(session, on_tcp_watchdog_timeout, + KR_CONN_RTT_MAX, 0); + if (ret < 0) { + assert(false); + session_del_waiting(session, task); + session_del_tasks(session, task); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + } + task->pending[task->pending_count] = session->handle; + task->pending_count += 1; + } else { + /* Make connection */ + uv_connect_t *conn = (uv_connect_t *)req_borrow(ctx->worker); + if (!conn) { + return qr_task_step(task, NULL, NULL); + } + uv_handle_t *client = ioreq_spawn(task, sock_type, + addr->sa_family); + if (!client) { + req_release(ctx->worker, (struct req *)conn); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + session = client->data; + ret = worker_add_tcp_waiting(ctx->worker, addr, session); + if (ret < 0) { + session_del_tasks(session, task); + req_release(ctx->worker, (struct req *)conn); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + /* will be removed in qr_task_on_send() */ + ret = session_add_waiting(session, task); + if (ret < 0) { + session_del_tasks(session, task); + worker_del_tcp_waiting(ctx->worker, addr); + req_release(ctx->worker, (struct req *)conn); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + + /* Check if there must be TLS */ + struct engine *engine = ctx->worker->engine; + struct network *net = &engine->net; + const char *key = tcpsess_key(addr); + struct tls_client_paramlist_entry *entry = map_get(&net->tls_client_params, key); + if (entry) { + assert(session->tls_client_ctx == NULL); + struct tls_client_ctx_t *tls_ctx = tls_client_ctx_new(entry); + if (!tls_ctx) { + session_del_tasks(session, task); + session_del_waiting(session, task); + worker_del_tcp_waiting(ctx->worker, addr); + req_release(ctx->worker, (struct req *)conn); + return qr_task_step(task, NULL, NULL); + } + tls_client_ctx_set_params(tls_ctx, entry); + session->tls_client_ctx = tls_ctx; + session->has_tls = true; + } + + conn->data = session; + memcpy(&session->peer, addr, sizeof(session->peer)); + + ret = timer_start(session, on_tcp_connect_timeout, + KR_CONN_RTT_MAX, 0); + if (ret != 0) { + session_del_tasks(session, task); + session_del_waiting(session, task); + worker_del_tcp_waiting(ctx->worker, addr); + req_release(ctx->worker, (struct req *)conn); + subreq_finalize(task, packet_source, packet); + return qr_task_finalize(task, KR_STATE_FAIL); + } + + if (uv_tcp_connect(conn, (uv_tcp_t *)client, + addr , on_connect) != 0) { + session_del_tasks(session, task); + session_del_waiting(session, task); + worker_del_tcp_waiting(ctx->worker, addr); + req_release(ctx->worker, (struct req *)conn); + return qr_task_step(task, NULL, NULL); + } + } } - return 0; + return kr_ok(); } static int parse_packet(knot_pkt_t *query) @@ -852,47 +1525,168 @@ static int parse_packet(knot_pkt_t *query) return kr_ok(); } -int worker_submit(struct worker_ctx *worker, uv_handle_t *handle, knot_pkt_t *msg, const struct sockaddr* addr) +static struct qr_task* find_task(const struct session *session, uint16_t msg_id) { - if (!worker || !handle) { + struct qr_task *ret = NULL; + const qr_tasklist_t *tasklist = &session->tasks; + for (size_t i = 0; i < tasklist->len; ++i) { + struct qr_task *task = tasklist->at[i]; + uint16_t task_msg_id = knot_wire_get_id(task->pktbuf->wire); + if (task_msg_id == msg_id) { + ret = task; + break; + } + } + return ret; +} + + +int worker_submit(struct worker_ctx *worker, uv_handle_t *handle, + knot_pkt_t *msg, const struct sockaddr* addr) +{ + bool OK = worker && handle && handle->data; + if (!OK) { + assert(false); return kr_error(EINVAL); } struct session *session = handle->data; - assert(session); /* Parse packet */ int ret = parse_packet(msg); - /* Start new task on listening sockets, or resume if this is subrequest */ + /* Start new task on listening sockets, + * or resume if this is subrequest */ struct qr_task *task = NULL; - if (!session->outgoing) { + if (!session->outgoing) { /* request from a client */ /* Ignore badly formed queries or responses. */ if (!msg || ret != 0 || knot_wire_get_qr(msg->wire)) { if (msg) worker->stats.dropped += 1; - return kr_error(EINVAL); /* Ignore. */ + return kr_error(EILSEQ); } - task = qr_task_create(worker, handle, addr); - if (!task) { + struct request_ctx *ctx = request_create(worker, handle, addr); + if (!ctx) { return kr_error(ENOMEM); } - ret = qr_task_start(task, msg); + + ret = request_start(ctx, msg); if (ret != 0) { - qr_task_free(task); + request_free(ctx); return kr_error(ENOMEM); } - } else { - task = session->tasks.len > 0 ? array_tail(session->tasks) : NULL; + + task = qr_task_create(ctx); + if (!task) { + request_free(ctx); + return kr_error(ENOMEM); + } + } else if (msg) { /* response from upstream */ + task = find_task(session, knot_wire_get_id(msg->wire)); } /* Consume input and produce next message */ - return qr_task_step(task, addr, msg); + return qr_task_step(task, NULL, msg); +} + +static int map_add_tcp_session(map_t *map, const struct sockaddr* addr, + struct session *session) +{ + assert(map && addr); + const char *key = tcpsess_key(addr); + assert(key); + assert(map_contains(map, key) == 0); + int ret = map_set(map, key, session); + return ret ? kr_error(EINVAL) : kr_ok(); +} + +static int map_del_tcp_session(map_t *map, const struct sockaddr* addr) +{ + assert(map && addr); + const char *key = tcpsess_key(addr); + assert(key); + int ret = map_del(map, key); + return ret ? kr_error(ENOENT) : kr_ok(); +} + +static struct session* map_find_tcp_session(map_t *map, + const struct sockaddr *addr) +{ + assert(map && addr); + const char *key = tcpsess_key(addr); + assert(key); + struct session* ret = map_get(map, key); + return ret; +} + +static int worker_add_tcp_connected(struct worker_ctx *worker, + const struct sockaddr* addr, + struct session *session) +{ + assert(addr); + const char *key = tcpsess_key(addr); + assert(key); + tcp_connected += 1; + assert(map_contains(&worker->tcp_connected, key) == 0); + return map_add_tcp_session(&worker->tcp_connected, addr, session); +} + +static int worker_del_tcp_connected(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + assert(addr); + const char *key = tcpsess_key(addr); + assert(key); + int ret = map_del_tcp_session(&worker->tcp_connected, addr); + if (ret == 0) { + tcp_connected -= 1; + } + return ret; +} + +static struct session* worker_find_tcp_connected(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + return map_find_tcp_session(&worker->tcp_connected, addr); +} + +static int worker_add_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr, + struct session *session) +{ + assert(addr); + const char *key = tcpsess_key(addr); + assert(key); + assert(map_contains(&worker->tcp_waiting, key) == 0); + int ret = map_add_tcp_session(&worker->tcp_waiting, addr, session); + if (ret == 0) { + tcp_waiting += 1; + } + return ret; +} + +static int worker_del_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + assert(addr); + const char *key = tcpsess_key(addr); + assert(key); + int ret = map_del_tcp_session(&worker->tcp_waiting, addr); + if (ret == 0) { + tcp_waiting -= 1; + } + return ret; +} + +static struct session* worker_find_tcp_waiting(struct worker_ctx *worker, + const struct sockaddr* addr) +{ + return map_find_tcp_session(&worker->tcp_waiting, addr); } /* Return DNS/TCP message size. */ -static int msg_size(const uint8_t *msg) +static int get_msg_size(const uint8_t *msg) { - return wire_read_u16(msg); + return wire_read_u16(msg); } /* If buffering, close last task as it isn't live yet. */ @@ -901,6 +1695,7 @@ static void discard_buffered(struct session *session) if (session->buffering) { qr_task_free(session->buffering); session->buffering = NULL; + session->msg_hdr_idx = 0; } } @@ -921,153 +1716,259 @@ int worker_end_tcp(struct worker_ctx *worker, uv_handle_t *handle) return 0; } -int worker_process_tcp(struct worker_ctx *worker, uv_stream_t *handle, const uint8_t *msg, ssize_t len) +int worker_process_tcp(struct worker_ctx *worker, uv_stream_t *handle, + const uint8_t *msg, ssize_t len) + { if (!worker || !handle) { return kr_error(EINVAL); } /* Connection error or forced disconnect */ struct session *session = handle->data; + if (session->closing) { + return kr_ok(); + } if (len <= 0 || !msg) { /* If we have pending tasks, we must dissociate them from the * connection so they don't try to access closed and freed handle. - * @warning Do not modify task if this is outgoing request as it is shared with originator. + * @warning Do not modify task if this is outgoing request + * as it is shared with originator. */ - if (!session->outgoing) { - for (size_t i = 0; i < session->tasks.len; ++i) { - struct qr_task *task = session->tasks.at[i]; - task->session = NULL; - task->source.handle = NULL; + uv_timer_t *timer = &session->timeout; + uv_timer_stop(timer); + while (session->waiting.len > 0) { + struct qr_task *task = session->waiting.at[0]; + if (session->outgoing) { + qr_task_finalize(task, KR_STATE_FAIL); + } + array_del(session->waiting, 0); + qr_task_unref(task); + session_del_tasks(session, task); + } + while (session->tasks.len > 0) { + struct qr_task *task = session->tasks.at[0]; + if (session->outgoing) { + qr_task_finalize(task, KR_STATE_FAIL); } - session->tasks.len = 0; + session_del_tasks(session, task); } + session_close(session); return kr_error(ECONNRESET); } + if (session->outgoing) { + uv_timer_stop(&session->timeout); + timer_start(session, on_tcp_watchdog_timeout, MAX_TCP_INACTIVITY, 0); + if (session->bytes_to_skip) { + session->buffering = NULL; + ssize_t min_len = MIN(session->bytes_to_skip, len); + len -= min_len; + msg += min_len; + session->bytes_to_skip -= min_len; + if (len < 0 || session->bytes_to_skip < 0) { + /* Something gone wrong. + * Better kill the connection */ + assert(false); + return kr_error(EILSEQ); + } + if (len == 0) { + return kr_ok(); + } + assert(session->bytes_to_skip == 0); + } + } + int submitted = 0; struct qr_task *task = session->buffering; + knot_pkt_t *pkt_buf = NULL; + if (task) { + pkt_buf = task->pktbuf; + } else { + /* Update DNS header in session->msg_hdr* */ + assert(session->msg_hdr_idx <= sizeof(session->msg_hdr)); + ssize_t hdr_amount = sizeof(session->msg_hdr) - + session->msg_hdr_idx; + if (hdr_amount > len) { + hdr_amount = len; + } + if (hdr_amount > 0) { + memcpy(session->msg_hdr + session->msg_hdr_idx, msg, hdr_amount); + session->msg_hdr_idx += hdr_amount; + len -= hdr_amount; + msg += hdr_amount; + } + if (len == 0) { /* no data beyond msg_hdr -> not much to do */ + return kr_ok(); + } + assert(session->msg_hdr_idx == sizeof(session->msg_hdr)); + session->msg_hdr_idx = 0; + uint16_t msg_size = get_msg_size(session->msg_hdr); + uint16_t msg_id = knot_wire_get_id(session->msg_hdr + 2); + if (msg_size < KNOT_WIRE_HEADER_SIZE) { + /* better kill the connection; we would probably get out of sync */ + assert(false); + return kr_error(EILSEQ); + } - /* If this is a new query, create a new task that we can use - * to buffer incoming message until it's complete. */ - if (!session->outgoing) { - if (!task) { - /* Get TCP peer name, keep zeroed address if it fails. */ - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - int addr_len = sizeof(addr); - uv_tcp_getpeername((uv_tcp_t *)handle, (struct sockaddr *)&addr, &addr_len); - task = qr_task_create(worker, (uv_handle_t *)handle, (struct sockaddr *)&addr); + /* get task */ + if (!session->outgoing) { + /* This is a new query, create a new task that we can use + * to buffer incoming message until it's complete. */ + struct sockaddr_storage addr_storage; + struct sockaddr *addr = (struct sockaddr *)&addr_storage; + int addr_len = sizeof(addr_storage); + int ret = uv_tcp_getpeername((uv_tcp_t *)handle, addr, &addr_len); + if (ret) { + addr = NULL; /* fallback */ + } + struct request_ctx *ctx = request_create(worker, + (uv_handle_t *)handle, + addr); + if (!ctx) { + assert(false); + return kr_error(ENOMEM); + } + task = qr_task_create(ctx); if (!task) { + assert(false); + request_free(ctx); return kr_error(ENOMEM); } - session->buffering = task; + } else { + /* Start of response from upstream. + * The session task list must contain a task + * with the same msg id. */ + task = find_task(session, msg_id); + /* FIXME: on high load over one connection, it's likely + * that we will get multiple matches sooner or later (!) */ + if (task) { + knot_pkt_clear(task->pktbuf); + } else { + /* TODO: only ignore one message without killing connection */ + session->buffering = NULL; + session->bytes_to_skip = msg_size - 2; + ssize_t min_len = MIN(session->bytes_to_skip, len); + len -= min_len; + msg += min_len; + session->bytes_to_skip -= min_len; + if (len < 0 || session->bytes_to_skip < 0) { + /* Something gone wrong. + * Better kill the connection */ + assert(false); + return kr_error(EILSEQ); + } + if (len == 0) { + return kr_ok(); + } + assert(session->bytes_to_skip == 0); + int ret = worker_process_tcp(worker, handle, msg, len); + submitted += ret; + return submitted; + } } - } else { - assert(session->tasks.len > 0); - task = array_tail(session->tasks); + + pkt_buf = task->pktbuf; + knot_wire_set_id(pkt_buf->wire, msg_id); + pkt_buf->size = 2; + task->bytes_remaining = msg_size - 2; + session->buffering = task; } - /* At this point session must have either created new task or it's already assigned. */ + /* At this point session must have either created new task + * or it's already assigned. */ assert(task); assert(len > 0); - /* Start reading DNS/TCP message length */ - knot_pkt_t *pkt_buf = task->pktbuf; - if (task->bytes_remaining == 0 && pkt_buf->size == 0) { - knot_pkt_clear(pkt_buf); - /* Read only one byte as TCP fragment may end at a 1B boundary - * which would lead to OOB read or improper reassembly length. */ - pkt_buf->size = 1; - pkt_buf->wire[0] = msg[0]; - len -= 1; - msg += 1; - if (len == 0) { - return 0; - } - } - /* Finish reading DNS/TCP message length. */ - if (task->bytes_remaining == 0 && pkt_buf->size == 1) { - pkt_buf->wire[1] = msg[0]; - ssize_t nbytes = msg_size(pkt_buf->wire); - len -= 1; - msg += 1; - /* Cut off fragment length and start reading DNS message. */ - pkt_buf->size = 0; - task->bytes_remaining = nbytes; - } /* Message is too long, can't process it. */ ssize_t to_read = MIN(len, task->bytes_remaining); if (pkt_buf->size + to_read > pkt_buf->max_size) { pkt_buf->size = 0; task->bytes_remaining = 0; + /* TODO: only ignore one message without killing connection */ + session->buffering = NULL; return kr_error(EMSGSIZE); } /* Buffer message and check if it's complete */ memcpy(pkt_buf->wire + pkt_buf->size, msg, to_read); pkt_buf->size += to_read; - if (to_read >= task->bytes_remaining) { - task->bytes_remaining = 0; + task->bytes_remaining -= to_read; + if (task->bytes_remaining == 0) { + /* Message was assembled, clear temporary. */ + session->buffering = NULL; + session->msg_hdr_idx = 0; + session_del_tasks(session, task); /* Parse the packet and start resolving complete query */ int ret = parse_packet(pkt_buf); if (ret == 0 && !session->outgoing) { - ret = qr_task_start(task, pkt_buf); - if (ret != 0) { - return ret; + /* Start only new queries, + * not subrequests that are already pending */ + ret = request_start(task->ctx, pkt_buf); + assert(ret == 0); + if (ret == 0) { + ret = qr_task_register(task, session); } - ret = qr_task_register(task, session); - if (ret != 0) { - return ret; + if (ret == 0) { + submitted += 1; + } + if (task->leading) { + assert(false); } - /* Task is now registered in session, clear temporary. */ - session->buffering = NULL; - submitted += 1; } - /* Start only new queries, not subrequests that are already pending */ if (ret == 0) { - ret = qr_task_step(task, NULL, pkt_buf); + const struct sockaddr *addr = session->outgoing ? &session->peer.ip : NULL; + ret = qr_task_step(task, addr, pkt_buf); } - /* Process next message part in the stream if no error so far */ - if (ret != 0) { - return ret; - } - if (len - to_read > 0 && !session->outgoing) { + if (len - to_read > 0) { + /* TODO: this is simple via iteration; recursion doesn't really help */ ret = worker_process_tcp(worker, handle, msg + to_read, len - to_read); if (ret < 0) { return ret; } submitted += ret; } - } else { - task->bytes_remaining -= to_read; } return submitted; } -int worker_resolve(struct worker_ctx *worker, knot_pkt_t *query, struct kr_qflags options, - worker_cb_t on_complete, void *baton) +int worker_resolve(struct worker_ctx *worker, knot_pkt_t *query, + struct kr_qflags options, worker_cb_t on_complete, + void *baton) { if (!worker || !query) { + assert(false); return kr_error(EINVAL); } + struct request_ctx *ctx = request_create(worker, NULL, NULL); + if (!ctx) { + return kr_error(ENOMEM); + } + /* Create task */ - struct qr_task *task = qr_task_create(worker, NULL, NULL); + struct qr_task *task = qr_task_create(ctx); if (!task) { + request_free(ctx); return kr_error(ENOMEM); } task->baton = baton; task->on_complete = on_complete; /* Start task */ - int ret = qr_task_start(task, query); + int ret = request_start(ctx, query); /* Set options late, as qr_task_start() -> kr_resolve_begin() rewrite it. */ - kr_qflags_set(&task->req.options, options); + kr_qflags_set(&task->ctx->req.options, options); if (ret != 0) { + request_free(ctx); qr_task_unref(task); return ret; } return qr_task_step(task, NULL, query); } +void worker_session_close(struct session *session) +{ + session_close(session); +} + /** Reserve worker buffers */ static int worker_reserve(struct worker_ctx *worker, size_t ring_maxlen) { @@ -1076,13 +1977,17 @@ static int worker_reserve(struct worker_ctx *worker, size_t ring_maxlen) array_init(worker->pool_sessions); if (array_reserve(worker->pool_mp, ring_maxlen) || array_reserve(worker->pool_ioreq, ring_maxlen) || - array_reserve(worker->pool_sessions, ring_maxlen)) + array_reserve(worker->pool_sessions, ring_maxlen)) { return kr_error(ENOMEM); + } memset(&worker->pkt_pool, 0, sizeof(worker->pkt_pool)); worker->pkt_pool.ctx = mp_new (4 * sizeof(knot_pkt_t)); worker->pkt_pool.alloc = (knot_mm_alloc_t) mp_alloc; worker->outgoing = map_make(); + worker->tcp_connected = map_make(); + worker->tcp_waiting = map_make(); worker->tcp_pipeline_max = MAX_PIPELINED; + memset(&worker->stats, 0, sizeof(worker->stats)); return kr_ok(); } @@ -1102,6 +2007,8 @@ void worker_reclaim(struct worker_ctx *worker) mp_delete(worker->pkt_pool.ctx); worker->pkt_pool.ctx = NULL; map_clear(&worker->outgoing); + map_clear(&worker->tcp_connected); + map_clear(&worker->tcp_waiting); } struct worker_ctx *worker_create(struct engine *engine, knot_mm_t *pool, diff --git a/daemon/worker.h b/daemon/worker.h index 55357e88f..97883e761 100644 --- a/daemon/worker.h +++ b/daemon/worker.h @@ -21,8 +21,12 @@ #include "lib/generic/map.h" +/** Query resolution task (opaque). */ +struct qr_task; /** Worker state (opaque). */ struct worker_ctx; +/** Transport session (opaque). */ +struct session; /** Worker callback */ typedef void (*worker_cb_t)(struct worker_ctx *worker, struct kr_request *req, void *baton); @@ -31,14 +35,20 @@ struct worker_ctx *worker_create(struct engine *engine, knot_mm_t *pool, int worker_id, int worker_count); /** - * Process incoming packet (query or answer to subrequest). + * Process an incoming packet (query from a client or answer from upstream). + * + * @param worker the singleton worker + * @param handle socket through which the request came + * @param query the packet, or NULL on an error from the transport layer + * @param addr the address from which the packet came (or NULL, possibly, on error) * @return 0 or an error code */ int worker_submit(struct worker_ctx *worker, uv_handle_t *handle, knot_pkt_t *query, const struct sockaddr* addr); /** - * Process incoming DNS/TCP message fragment(s). + * Process incoming DNS message fragment(s) that arrived over a stream (TCP, TLS). + * * If the fragment contains only a partial message, it is buffered. * If the fragment contains a complete query or completes current fragment, execute it. * @return the number of newly-completed requests (>=0) or an error code @@ -55,6 +65,7 @@ int worker_end_tcp(struct worker_ctx *worker, uv_handle_t *handle); /** * Schedule query for resolution. * + * After resolution finishes, invoke on_complete with baton. * @return 0 or an error code * * @note the options passed are |-combined with struct kr_context::options @@ -66,15 +77,27 @@ int worker_resolve(struct worker_ctx *worker, knot_pkt_t *query, struct kr_qflag /** Collect worker mempools */ void worker_reclaim(struct worker_ctx *worker); +/** Closes given session */ +void worker_session_close(struct session *session); + /** @cond internal */ /** Number of request within timeout window. */ #define MAX_PENDING KR_NSREP_MAXADDR +/** Maximum response time from TCP upstream, milliseconds */ +#define MAX_TCP_INACTIVITY 10000 + /** Freelist of available mempools. */ typedef array_t(void *) mp_freelist_t; +/** List of query resolution tasks. */ +typedef array_t(struct qr_task *) qr_tasklist_t; + +/** Session list. */ +typedef array_t(struct session *) qr_sessionlist_t; + /** \details Worker state is meant to persist during the whole life of daemon. */ struct worker_ctx { struct engine *engine; @@ -94,6 +117,7 @@ struct worker_ctx { #endif struct { size_t concurrent; + size_t rconcurrent; size_t udp; size_t tcp; size_t ipv4; @@ -103,6 +127,12 @@ struct worker_ctx { size_t timeout; } stats; + bool too_many_open; + size_t rconcurrent_highwatermark; + /* List of active outbound TCP sessions */ + map_t tcp_connected; + /* List of outbound TCP sessions waiting to be accepted */ + map_t tcp_waiting; map_t outgoing; mp_freelist_t pool_mp; mp_freelist_t pool_ioreq; @@ -110,34 +140,6 @@ struct worker_ctx { knot_mm_t pkt_pool; }; -/** Query resolution task. */ -struct qr_task -{ - struct kr_request req; - struct worker_ctx *worker; - struct session *session; - knot_pkt_t *pktbuf; - array_t(struct qr_task *) waiting; - uv_handle_t *pending[MAX_PENDING]; - uint16_t pending_count; - uint16_t addrlist_count; - uint16_t addrlist_turn; - uint16_t timeouts; - uint16_t iter_count; - uint16_t bytes_remaining; - struct sockaddr *addrlist; - uv_timer_t *timeout; - worker_cb_t on_complete; - void *baton; - struct { - union inaddr addr; - union inaddr dst_addr; - uv_handle_t *handle; - } source; - uint32_t refs; - bool finished : 1; - bool leading : 1; -}; /** @endcond */ diff --git a/lib/layer/iterate.c b/lib/layer/iterate.c index 9d5ded5ce..9a9d31140 100644 --- a/lib/layer/iterate.c +++ b/lib/layer/iterate.c @@ -823,7 +823,7 @@ int kr_make_query(struct kr_query *query, knot_pkt_t *pkt) char name_str[KNOT_DNAME_MAXLEN], type_str[16]; knot_dname_to_str(name_str, query->sname, sizeof(name_str)); knot_rrtype_to_string(query->stype, type_str, sizeof(type_str)); - QVERBOSE_MSG(query, "'%s' type '%s' id was assigned, parent id %hu\n", + QVERBOSE_MSG(query, "'%s' type '%s' id was assigned, parent id %u\n", name_str, type_str, query->parent ? query->parent->id : 0); } return kr_ok(); diff --git a/lib/resolve.c b/lib/resolve.c index 04445afd3..1589a2583 100644 --- a/lib/resolve.c +++ b/lib/resolve.c @@ -1520,7 +1520,6 @@ int kr_resolve_checkout(struct kr_request *request, struct sockaddr *src, if (ret != 0) { return kr_error(EINVAL); } - WITH_VERBOSE { char qname_str[KNOT_DNAME_MAXLEN], zonecut_str[KNOT_DNAME_MAXLEN], ns_str[INET6_ADDRSTRLEN], type_str[16]; knot_dname_to_str(qname_str, knot_pkt_qname(packet), sizeof(qname_str)); diff --git a/lib/utils.c b/lib/utils.c index bb4b50e1e..4cdff4297 100644 --- a/lib/utils.c +++ b/lib/utils.c @@ -318,6 +318,31 @@ uint16_t kr_inaddr_port(const struct sockaddr *addr) } } +int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen) +{ + int ret = kr_ok(); + if (!addr || !buf || !buflen) { + return kr_error(EINVAL); + } + + char str[INET6_ADDRSTRLEN + 6]; + if (!inet_ntop(addr->sa_family, kr_inaddr(addr), str, sizeof(str))) { + return kr_error(errno); + } + int len = strlen(str); + str[len] = '#'; + u16tostr((uint8_t *)&str[len + 1], kr_inaddr_port(addr)); + len += 6; + str[len] = 0; + if (len >= *buflen) { + ret = kr_error(ENOSPC); + } else { + memcpy(buf, str, len + 1); + } + *buflen = len; + return ret; +} + int kr_straddr_family(const char *addr) { if (!addr) { @@ -396,6 +421,84 @@ int kr_straddr_subnet(void *dst, const char *addr) return bit_len; } +int kr_straddr_split(const char *addr, char *buf, size_t buflen, uint16_t *port) +{ + const int base = 10; + long p = 0; + size_t addrlen = strlen(addr); + char *p_start = strchr(addr, '@'); + char *p_end; + + if (!p_start) { + p_start = strchr(addr, '#'); + } + + if (p_start) { + if (p_start[1] != '\0'){ + p = strtol(p_start + 1, &p_end, base); + if (*p_end != '\0' || p <= 0 || p > UINT16_MAX) { + return kr_error(EINVAL); + } + } + addrlen = p_start - addr; + } + + /* Check if address is valid. */ + if (addrlen >= INET6_ADDRSTRLEN) { + return kr_error(EINVAL); + } + + char str[INET6_ADDRSTRLEN]; + struct sockaddr_storage ss; + + memcpy(str, addr, addrlen); str[addrlen] = '\0'; + + int family = kr_straddr_family(str); + if (family == kr_error(EINVAL) || !inet_pton(family, str, &ss)) { + return kr_error(EINVAL); + } + + /* Address and port contains valid values, return it to caller */ + if (buf) { + if (addrlen >= buflen) { + return kr_error(ENOSPC); + } + memcpy(buf, addr, addrlen); buf[addrlen] = '\0'; + } + if (port) { + *port = (uint16_t)p; + } + + return kr_ok(); +} + +int kr_straddr_join(const char *addr, uint16_t port, char *buf, size_t *buflen) +{ + if (!addr || !buf || !buflen) { + return kr_error(EINVAL); + } + + struct sockaddr_storage ss; + int family = kr_straddr_family(addr); + if (family == kr_error(EINVAL) || !inet_pton(family, addr, &ss)) { + return kr_error(EINVAL); + } + + int len = strlen(addr); + if (len + 6 >= *buflen) { + return kr_error(ENOSPC); + } + + memcpy(buf, addr, len + 1); + buf[len] = '#'; + u16tostr((uint8_t *)&buf[len + 1], port); + len += 6; + buf[len] = 0; + *buflen = len; + + return kr_ok(); +} + int kr_bitcmp(const char *a, const char *b, int bits) { /* We're using the function from lua directly, so at least for now diff --git a/lib/utils.h b/lib/utils.h index de727771a..5cd8d75d4 100644 --- a/lib/utils.h +++ b/lib/utils.h @@ -197,10 +197,13 @@ int kr_inaddr_len(const struct sockaddr *addr); /** Port. */ KR_EXPORT KR_PURE uint16_t kr_inaddr_port(const struct sockaddr *addr); +/** String representation for given address as "#" */ +KR_EXPORT +int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen); /** Return address type for string. */ KR_EXPORT KR_PURE int kr_straddr_family(const char *addr); -/** Return address length in given family. */ +/** Return address length in given family (struct in*_addr). */ KR_EXPORT KR_CONST int kr_family_len(int family); /** Create a sockaddr* from string+port representation (also accepts IPv6 link-local). */ @@ -211,6 +214,26 @@ struct sockaddr * kr_straddr_socket(const char *addr, int port); KR_EXPORT int kr_straddr_subnet(void *dst, const char *addr); +/** Splits ip address specified as "addr@port" or "addr#port" into addr and port + * and performs validation. + * @note if #port part isn't present, then port will be set to 0. + * buf and\or port can be set to NULL. + * @return kr_error(EINVAL) - addr part doesn't contains valid ip address or + * #port part is out-of-range (either < 0 either > UINT16_MAX) + * kr_error(ENOSP) - buflen is too small + */ +KR_EXPORT +int kr_straddr_split(const char *addr, char *buf, size_t buflen, uint16_t *port); +/** Formats ip address and port in "addr#port" format. + * and performs validation. + * @note Port always formatted as five-character string with leading zeros. + * @return kr_error(EINVAL) - addr or buf is NULL or buflen is 0 or + * addr doesn't contain a valid ip address + * kr_error(ENOSP) - buflen is too small + */ +KR_EXPORT +int kr_straddr_join(const char *addr, uint16_t port, char *buf, size_t *buflen); + /** Compare memory bitwise. The semantics is "the same" as for memcmp(). * The partial byte is considered with more-significant bits first, * so this is e.g. suitable for comparing IP prefixes. */ @@ -300,6 +323,19 @@ static inline const char *lua_push_printf(lua_State *L, const char *fmt, ...) return ret; } +/** @internal Return string representation of addr. + * @note return pointer to static string + */ +static inline char *kr_straddr(const struct sockaddr *addr) +{ + assert(addr != NULL); + /* We are the sinle-threaded application */ + static char str[INET6_ADDRSTRLEN + 6]; + size_t len = sizeof(str); + int ret = kr_inaddr_str(addr, str, &len); + return ret != kr_ok() || len == 0 ? NULL : str; +} + /** The current time in monotonic milliseconds. * * \note it may be outdated in case of long callbacks; see uv_now(). diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua index 2e975a673..9fa135e32 100644 --- a/modules/policy/policy.lua +++ b/modules/policy/policy.lua @@ -120,6 +120,32 @@ local function forward(target) end end +-- Forward request and all subrequests to upstream over TCP; validate answers +local function tcp_forward(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v)) + assert(#list <= 4, 'at most 4 TCP_FORWARD targets are supported') + end + else + table.insert(list, addr2sock(target)) + end + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + req.options.TCP = true + qry.flags.TCP = true + set_nslist(qry, list) + return state + end +end + -- Rewrite records in packet local function reroute(tbl, names) -- Import renumbering rules @@ -236,7 +262,8 @@ end local policy = { -- Policies PASS = 1, DENY = 2, DROP = 3, TC = 4, QTRACE = 5, - FORWARD = forward, STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags, + FORWARD = forward, TCP_FORWARD = tcp_forward, + STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags, -- Special values ANY = 0, } -- 2.47.2