]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon/worker: fixes error handling from TLS writes
authorMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 17 Aug 2018 07:43:36 +0000 (00:43 -0700)
committerGrigorii Demidov <grigorii.demidov@nic.cz>
Fri, 14 Sep 2018 12:40:38 +0000 (14:40 +0200)
The error handling loop for uncorking TLS data was wrong, as the
underlying push function is asynchronous and there's no relationship
between completed DNS packet writes and number of TLS message writes.
In the asynchronous case, the buffered data must remain valid
until the write completes. Currently this is not guaranteed, and
loading the resolver with pipelined requests results in memory errors:

```
$ getdns_query @127.0.0.1#853 -s -a -s -l L -B -F queries -q
...
==47111==ERROR: AddressSanitizer: heap-use-after-free on address 0x6290040a1253 at pc 0x00010da960d3 bp 0x7ffee2628b30 sp 0x7ffee26282e0
READ of size 499 at 0x6290040a1253 thread T0
    #0 0x10da960d2 in wrap_write (libclang_rt.asan_osx_dynamic.dylib:x86_64h+0x1f0d2)
    #1 0x10d855971 in uv__write (libuv.1.dylib:x86_64+0xf971)
    #2 0x10d85422e in uv__stream_io (libuv.1.dylib:x86_64+0xe22e)
    #3 0x10d85b35a in uv__io_poll (libuv.1.dylib:x86_64+0x1535a)
    #4 0x10d84c644 in uv_run (libuv.1.dylib:x86_64+0x6644)
    #5 0x10d602ddf in main main.c:422
    #6 0x7fff6a28a014 in start (libdyld.dylib:x86_64+0x1014)

0x6290040a1253 is located 83 bytes inside of 16895-byte region [0x6290040a1200,0x6290040a53ff)
freed by thread T0 here:
    #0 0x10dacdfdd in wrap_free (libclang_rt.asan_osx_dynamic.dylib:x86_64h+0x56fdd)
    #1 0x10d913c2e in _mbuffer_head_remove_bytes (libgnutls.30.dylib:x86_64+0xbc2e)
    #2 0x10d915080 in _gnutls_io_write_flush (libgnutls.30.dylib:x86_64+0xd080)
    #3 0x10d90ca18 in _gnutls_send_tlen_int (libgnutls.30.dylib:x86_64+0x4a18)
    #4 0x10d90edde in gnutls_record_send2 (libgnutls.30.dylib:x86_64+0x6dde)
    #5 0x10d90f085 in gnutls_record_uncork (libgnutls.30.dylib:x86_64+0x7085)
    #6 0x10d5f6569 in tls_push tls.c:238
    #7 0x10d5e5b2a in qr_task_send worker.c:1002
    #8 0x10d5e2ea6 in qr_task_finalize worker.c:1562
    #9 0x10d5dab99 in qr_task_step worker.c
    #10 0x10d5e12fe in worker_process_tcp worker.c:2410
```

The current implementation adds an opportunistic uv_try_write, which
either writes the requested data or returns UV_EAGAIN or another error.
On UV_EAGAIN it falls back to a slower asynchronous write that copies the buffered data.

The function signature is changed from simple write to vectorized write.

This also enables TLS False Start to save 1RTT when possible.

daemon/bindings.c
daemon/io.c
daemon/io.h
daemon/network.c
daemon/network.h
daemon/tls.c
daemon/tls.h
daemon/worker.c
daemon/worker.h
lib/utils.c
lib/utils.h

index 313b2bb268e561e266d07fc3295fada7d5be3fa3..f7b2e7ea4aabb7aca6055f0a53faa0655b924750 100644 (file)
@@ -793,37 +793,53 @@ static int net_outgoing(lua_State *L, int family)
 static int net_outgoing_v4(lua_State *L) { return net_outgoing(L, AF_INET); }
 static int net_outgoing_v6(lua_State *L) { return net_outgoing(L, AF_INET6); }
 
-static int net_tcp_in_idle(lua_State *L)
+/* Shared Lua getter/setter for a timeout value.
+ * Called with no arguments it pushes the current *timeout; called with one
+ * positive number it stores that number into *timeout and pushes true.
+ * Any other call raises a Lua error whose message starts with `name`. */
+static int net_update_timeout(lua_State *L, uint64_t *timeout, const char *name)
 {
-       struct engine *engine = engine_luaget(L);
-       struct network *net = &engine->net;
-
        /* Only return current idle timeout. */
        if (lua_gettop(L) == 0) {
-               lua_pushnumber(L, net->tcp.in_idle_timeout);
+               lua_pushnumber(L, *timeout);
                return 1;
        }
 
        if ((lua_gettop(L) != 1)) {
-               lua_pushstring(L, "net.tcp_in_idle takes one parameter: (\"idle timeout\")");
+               /* lua_error() raises only the value on top of the stack, so the
+                * message has to be built as one string to keep the name prefix. */
+               lua_pushfstring(L, "%s takes one parameter: (\"idle timeout\")", name);
                lua_error(L);
        }
 
        if (lua_isnumber(L, 1)) {
                int idle_timeout = lua_tointeger(L, 1);
                if (idle_timeout <= 0) {
-                       lua_pushstring(L, "net.tcp_in_idle parameter has to be positive number");
+                       lua_pushfstring(L, "%s parameter has to be positive number", name);
                        lua_error(L);
                }
-               net->tcp.in_idle_timeout = idle_timeout;
+               *timeout = idle_timeout;
        } else {
-               lua_pushstring(L, "net.tcp_in_idle parameter has to be positive number");
+               lua_pushfstring(L, "%s parameter has to be positive number", name);
                lua_error(L);
        }
        lua_pushboolean(L, true);
        return 1;
 }
 
+/* Lua binding: reads or updates net.tcp.in_idle_timeout of the engine bound
+ * to this Lua state, delegating argument handling to net_update_timeout(). */
+static int net_tcp_in_idle(lua_State *L)
+{
+       struct network *net = &engine_luaget(L)->net;
+       uint64_t *slot = &net->tcp.in_idle_timeout;
+
+       return net_update_timeout(L, slot, "net.tcp_in_idle");
+}
+
+/* Lua binding: reads or updates net.tcp.tls_handshake_timeout of the engine
+ * bound to this Lua state, delegating argument handling to net_update_timeout(). */
+static int net_tls_handshake_timeout(lua_State *L)
+{
+       struct network *net = &engine_luaget(L)->net;
+       uint64_t *slot = &net->tcp.tls_handshake_timeout;
+
+       return net_update_timeout(L, slot, "net.tls_handshake_timeout");
+}
+
 int lib_net(lua_State *L)
 {
        static const luaL_Reg lib[] = {
@@ -842,6 +858,7 @@ int lib_net(lua_State *L)
                { "outgoing_v4",  net_outgoing_v4 },
                { "outgoing_v6",  net_outgoing_v6 },
                { "tcp_in_idle",  net_tcp_in_idle },
+               { "tls_handshake_timeout",  net_tls_handshake_timeout },
                { NULL, NULL }
        };
        register_lib(L, "net", lib);
index e6ee4a26022f84a4d63416713221aac4fc058251..ae39261b6b1bce87343f6a134bc20607332d7689 100644 (file)
@@ -108,17 +108,6 @@ static void session_release(struct worker_ctx *worker, uv_handle_t *handle)
        }
 }
 
-static uv_stream_t *handle_borrow(uv_loop_t *loop)
-{
-       struct worker_ctx *worker = loop->data;
-       void *req = worker_iohandle_borrow(worker);
-       if (!req) {
-               return NULL;
-       }
-
-       return (uv_stream_t *)req;
-}
-
 static void handle_getbuf(uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf)
 {
        /* Worker has single buffer which is reused for all incoming
@@ -276,12 +265,21 @@ static void _tcp_accept(uv_stream_t *master, int status, bool tls)
                return;
        }
 
-       uv_stream_t *client = handle_borrow(master->loop);
+       struct worker_ctx *worker = (struct worker_ctx *)master->loop->data;
+       uv_stream_t *client = worker_iohandle_borrow(worker);
        if (!client) {
                return;
        }
        memset(client, 0, sizeof(*client));
-       io_create(master->loop, (uv_handle_t *)client, SOCK_STREAM);
+       int res = io_create(master->loop, (uv_handle_t *)client, SOCK_STREAM, 0);
+       if (res) {
+               if (res == UV_EMFILE) {
+                       worker->too_many_open = true;
+                       worker->rconcurrent_highwatermark = worker->stats.rconcurrent;
+               }
+               worker_iohandle_release(worker, client);
+               return;
+       }
        if (uv_accept(master, client) != 0) {
                uv_close((uv_handle_t *)client, io_release);
                return;
@@ -297,11 +295,11 @@ static void _tcp_accept(uv_stream_t *master, int status, bool tls)
        int addr_len = sizeof(union inaddr);
        int ret = uv_tcp_getpeername((uv_tcp_t *)client, addr, &addr_len);
        if (ret || addr->sa_family == AF_UNSPEC) {
+               worker_iohandle_release(worker, client);
                worker_session_close(session);
                return;
        }
 
-       const struct worker_ctx *worker = (struct worker_ctx *)master->loop->data;
        const struct engine *engine = worker->engine;
        const struct network *net = &engine->net;
        uint64_t idle_in_timeout = net->tcp.in_idle_timeout;
@@ -424,16 +422,18 @@ int tcp_bindfd_tls(uv_tcp_t *handle, int fd)
        return _tcp_bindfd(handle, fd, tls_accept);
 }
 
-void io_create(uv_loop_t *loop, uv_handle_t *handle, int type)
+int io_create(uv_loop_t *loop, uv_handle_t *handle, int type, unsigned family)
 {
-       int ret = -1;
+       int ret = 0;
        if (type == SOCK_DGRAM) {
                ret = uv_udp_init(loop, (uv_udp_t *)handle);
        } else if (type == SOCK_STREAM) {
-               ret = uv_tcp_init(loop, (uv_tcp_t *)handle);
+               ret = uv_tcp_init_ex(loop, (uv_tcp_t *)handle, family);
                uv_tcp_nodelay((uv_tcp_t *)handle, 1);
        }
-       assert(ret == 0);
+       if (ret != 0) {
+               return ret;
+       }
        struct worker_ctx *worker = loop->data;
        struct session *session = session_borrow(worker);
        assert(session);
@@ -441,6 +441,7 @@ void io_create(uv_loop_t *loop, uv_handle_t *handle, int type)
        handle->data = session;
        session->timeout.data = session;
        uv_timer_init(worker->loop, &session->timeout);
+       return ret;
 }
 
 void io_deinit(uv_handle_t *handle)
index 24c0c26e7be6d5e50f600005d278f9c19fc15c73..51976b07e3283ec72dbccfb3931a31c699f44673 100644 (file)
@@ -60,7 +60,7 @@ int tcp_bindfd(uv_tcp_t *handle, int fd);
 int tcp_bindfd_tls(uv_tcp_t *handle, int fd);
 
 /** Initialize the handle, incl. ->data = struct session * instance. type = SOCK_* */
-void io_create(uv_loop_t *loop, uv_handle_t *handle, int type);
+int io_create(uv_loop_t *loop, uv_handle_t *handle, int type, unsigned family);
 void io_deinit(uv_handle_t *handle);
 void io_free(uv_handle_t *handle);
 
index a790474deb28f7b3a71bc641b1f5b38e4cc3bc16..73282786313b8bd584019b4a2656419a9bf9b4a4 100644 (file)
@@ -53,8 +53,9 @@ void network_init(struct network *net, uv_loop_t *loop)
                net->endpoints = map_make(NULL);
                net->tls_client_params = map_make(NULL);
                net->tls_session_ticket_ctx = /* unsync. random, by default */
-                       tls_session_ticket_ctx_create(loop, NULL, 0);
+               tls_session_ticket_ctx_create(loop, NULL, 0);
                net->tcp.in_idle_timeout = 10000;
+               net->tcp.tls_handshake_timeout = TLS_MAX_HANDSHAKE_TIME;
        }
 }
 
index 15aac9e0ba1cb209d110178c7f5cd048d749ab63..cc7f2785402cbd5385549b2ac6d3c6decf8c3837 100644 (file)
@@ -44,6 +44,7 @@ typedef array_t(struct endpoint*) endpoint_array_t;
 
 struct net_tcp_param {
        uint64_t in_idle_timeout;
+       uint64_t tls_handshake_timeout;
 };
 
 struct tls_session_ticket_ctx;
index 5424cfc59f6419b4f114b3f222c69978a97bd3b7..b182c4e0492f94d512f0857fd5cb38c17c62259a 100644 (file)
@@ -90,6 +90,150 @@ static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
        return transfer;
 }
 
+/* Completion callback for the asynchronous uv_write() fallback.
+ * Frees the heap copy of the payload kept in req->data, then the request
+ * itself. The write status is not inspected. */
+static void on_write_complete(uv_write_t *req, int status)
+{
+       assert(req->data != NULL);
+       void *payload = req->data;
+
+       free(payload);
+       free(req);
+}
+
+/* Tells whether libuv has no pending write data queued on the stream.
+ * uv_stream_get_write_queue_size() is available since libuv 1.19.0;
+ * UV_VERSION_HEX packs one byte per component, so 1.19.0 is 0x011300.
+ * On older libuv the queue is assumed to be empty. */
+static bool stream_queue_is_empty(uv_stream_t *handle)
+{
+#if UV_VERSION_HEX >= 0x011300
+       return uv_stream_get_write_queue_size(handle) == 0;
+#else
+       /* Assume best case */
+       return true;
+#endif
+}
+
+/* GnuTLS vectored push callback: hands outgoing TLS data to the TCP stream.
+ *
+ * An immediate uv_try_write() is attempted first because it does not need the
+ * data to outlive this call. When libuv answers UV_EAGAIN (or the write queue
+ * is not empty), the iovec is copied into a heap buffer that an asynchronous
+ * uv_write() request owns until on_write_complete() releases it.
+ *
+ * Returns the number of bytes written or queued, or -1 with errno set:
+ * EFAULT for a NULL context, ENOMEM when the copy cannot be allocated,
+ * EIO when libuv rejects the write. */
+static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt)
+{
+       struct tls_common_ctx *t = (struct tls_common_ctx *)h;
+
+       if (t == NULL) {
+               errno = EFAULT;
+               return -1;
+       }
+
+       if (iovcnt == 0) {
+               return 0;
+       }
+
+       assert(t->session && t->session->handle &&
+              t->session->handle->type == UV_TCP);
+       uv_stream_t *handle = (uv_stream_t *)t->session->handle;
+
+       /*
+        * This is a little bit complicated. There are two different writes:
+        * 1. Immediate, these don't need to own the buffered data and return immediately
+        * 2. Asynchronous, these need to own the buffers until the write completes
+        * In order to avoid copying the buffer, an immediate write is tried first if possible.
+        * If it isn't possible to write the data without queueing, an asynchronous write
+        * is created (with copied buffered data).
+        */
+
+       size_t total_len = 0;
+       uv_buf_t uv_buf[iovcnt];
+       for (int i = 0; i < iovcnt; ++i) {
+               uv_buf[i].base = iov[i].iov_base;
+               uv_buf[i].len = iov[i].iov_len;
+               total_len += iov[i].iov_len;
+       }
+
+       /* Try to perform the immediate write first to avoid copy */
+       int ret = 0;
+       if (stream_queue_is_empty(handle)) {
+               ret = uv_try_write(handle, uv_buf, iovcnt);
+               DEBUG_MSG("[%s] push %zu <%p> = %d\n",
+                   t->client_side ? "tls_client" : "tls", total_len, h, ret);
+               if (ret >= 0) {
+                       /* Possibly a partial write; the byte count is reported as is. */
+                       return ret;
+               }
+               if (ret != UV_EAGAIN) {
+                       /* The push contract is -1 plus errno, not a libuv error code. */
+                       errno = EIO;
+                       return -1;
+               }
+       }
+
+       /* Fallback when the queue is full, and it's not possible to do an immediate write */
+       int err = 0;
+       uv_write_t *write_req = NULL;
+       char *buf = malloc(total_len);
+       if (buf != NULL) {
+               write_req = calloc(1, sizeof(*write_req));
+       }
+
+       if (buf == NULL || write_req == NULL) {
+               /* Never reach uv_write() with a NULL request or a freed buffer. */
+               err = ENOMEM;
+       } else {
+               /* Copy the buffer into owned memory */
+               size_t off = 0;
+               for (int i = 0; i < iovcnt; ++i) {
+                       memcpy(buf + off, uv_buf[i].base, uv_buf[i].len);
+                       off += uv_buf[i].len;
+               }
+               uv_buf[0].base = buf;
+               uv_buf[0].len = total_len;
+               write_req->data = buf;
+
+               /* Perform an asynchronous write with a callback */
+               if (uv_write(write_req, handle, uv_buf, 1, on_write_complete) != 0) {
+                       err = EIO;
+               }
+       }
+
+       if (err == 0) {
+               /* Ownership of buf and write_req has passed to on_write_complete(). */
+               ret = total_len;
+       } else {
+               /* Each allocation is released exactly once; free(NULL) is a no-op. */
+               free(buf);
+               free(write_req);
+               ret = -1;
+       }
+
+       DEBUG_MSG("[%s] queued %zu <%p> = %d\n",
+           t->client_side ? "tls_client" : "tls", total_len, h, ret);
+
+       if (err != 0) {
+               /* Set last so that neither free() nor logging can clobber it. */
+               errno = err;
+       }
+       return ret;
+}
+
+/** Run one gnutls_handshake() step and translate its result.
+  * See https://gnutls.org/manual/html_node/TLS-handshake.html#TLS-handshake
+  * Returns kr_ok() on success or on a non-fatal error, kr_error(EAGAIN) when
+  * GnuTLS reports GNUTLS_E_AGAIN, and kr_error(EIO) on a fatal error.
+  * handshake_cb, when not NULL, is invoked with 0 on completion and -1 on a fatal error.
+  */
+static int tls_handshake(struct tls_common_ctx *ctx, tls_handshake_cb handshake_cb) {
+       struct session *session = ctx->session;
+       const char *logstring = ctx->client_side ? client_logstring : server_logstring;
+
+       const int err = gnutls_handshake(ctx->tls_session);
+
+       if (err == GNUTLS_E_SUCCESS) {
+               /* Handshake finished: record the state and notify the caller. */
+               ctx->handshake_state = TLS_HS_DONE;
+               kr_log_verbose("[%s] TLS handshake with %s has completed\n",
+                              logstring,  kr_straddr(&session->peer.ip));
+               if (handshake_cb) {
+                       handshake_cb(session, 0);
+               }
+               return kr_ok();
+       }
+
+       if (err == GNUTLS_E_AGAIN) {
+               return kr_error(EAGAIN);
+       }
+
+       if (gnutls_error_is_fatal(err)) {
+               /* Not recoverable: log, notify the caller and give up. */
+               kr_log_verbose("[%s] gnutls_handshake failed: %s (%d)\n",
+                            logstring,
+                            gnutls_strerror_name(err), err);
+               if (handshake_cb) {
+                       handshake_cb(session, -1);
+               }
+               return kr_error(EIO);
+       }
+
+       if (err == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+               /* Warning alerts are only logged (verbose mode). */
+               const char *alert_name = gnutls_alert_get_name(gnutls_alert_get(ctx->tls_session));
+               if (alert_name != NULL) {
+                       kr_log_verbose("[%s] TLS alert from %s received: %s\n",
+                                      logstring, kr_straddr(&session->peer.ip), alert_name);
+               }
+       }
+
+       /* Any other non-fatal result is treated as success by this step. */
+       return kr_ok();
+}
+
+
 struct tls_ctx_t *tls_new(struct worker_ctx *worker)
 {
        assert(worker != NULL);
@@ -158,7 +302,7 @@ struct tls_ctx_t *tls_new(struct worker_ctx *worker)
        tls->c.client_side = false;
 
        gnutls_transport_set_pull_function(tls->c.tls_session, kres_gnutls_pull);
-       gnutls_transport_set_push_function(tls->c.tls_session, worker_gnutls_push);
+       gnutls_transport_set_vec_push_function(tls->c.tls_session, kres_gnutls_vec_push);
        gnutls_transport_set_ptr(tls->c.tls_session, tls);
 
        if (net->tls_session_ticket_ctx) {
@@ -202,7 +346,7 @@ void tls_free(struct tls_ctx_t *tls)
        free(tls);
 }
 
-int tls_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt)
+int tls_write(uv_write_t *req, uv_handle_t *handle, knot_pkt_t *pkt, uv_write_cb cb)
 {
        if (!pkt || !handle || !handle->data) {
                return kr_error(EINVAL);
@@ -219,10 +363,6 @@ int tls_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt)
        const char *logstring = tls_ctx->client_side ? client_logstring : server_logstring;
        gnutls_session_t tls_session = tls_ctx->tls_session;
 
-       tls_ctx->task = task;
-
-       assert(gnutls_record_check_corked(tls_session) == 0);
-
        gnutls_record_cork(tls_session);
        ssize_t count = 0;
        if ((count = gnutls_record_send(tls_session, &pkt_size, sizeof(pkt_size)) < 0) ||
@@ -232,36 +372,24 @@ int tls_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt)
                return kr_error(EIO);
        }
 
-       ssize_t submitted = 0;
-       ssize_t retries = 0;
-       do {
-               count = gnutls_record_uncork(tls_session, 0);
-               if (count < 0) {
-                       if (gnutls_error_is_fatal(count)) {
-                               kr_log_error("[%s] gnutls_record_uncork failed: %s (%zd)\n",
-                                            logstring, gnutls_strerror_name(count), count);
-                               return kr_error(EIO);
-                       }
-                       if (++retries > TLS_MAX_UNCORK_RETRIES) {
-                               kr_log_error("[%s] gnutls_record_uncork: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n",
-                                            logstring, retries, gnutls_strerror_name(count), count);
-                               return kr_error(EIO);
-                       }
-               } else if (count != 0) {
-                       submitted += count;
-                       retries = 0;
-               } else if (gnutls_record_check_corked(tls_session) != 0) {
-                       if (++retries > TLS_MAX_UNCORK_RETRIES) {
-                               kr_log_error("[%s] gnutls_record_uncork: too many retries (%zd)\n",
-                                            logstring, retries);
-                               return kr_error(EIO);
-                       }
-               } else if (submitted != sizeof(pkt_size) + pkt->size) {
-                       kr_log_error("[%s] gnutls_record_uncork didn't send all data(%zd of %zd)\n",
-                                    logstring, submitted, sizeof(pkt_size) + pkt->size);
-                       return kr_error(EIO);
-               }
-       } while (submitted != sizeof(pkt_size) + pkt->size);
+       const ssize_t submitted = sizeof(pkt_size) + pkt->size;
+
+       int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT);
+       if (gnutls_error_is_fatal(ret)) {
+               kr_log_error("[%s] gnutls_record_uncork failed: %s (%d)\n",
+                            logstring, gnutls_strerror_name(ret), ret);
+               return kr_error(EIO);
+       }
+
+       if (ret != submitted) {
+               kr_log_error("[%s] gnutls_record_uncork didn't send all data (%d of %zd)\n",
+                            logstring, ret, submitted);
+               return kr_error(EIO);
+       }
+
+       /* The data is now accepted in gnutls internal buffers, the message can be treated as sent */
+       req->handle = (uv_stream_t *)handle;
+       cb(req, 0);
 
        return kr_ok();
 }
@@ -283,47 +411,38 @@ int tls_process(struct worker_ctx *worker, uv_stream_t *handle, const uint8_t *b
        tls_p->nread = nread >= 0 ? nread : 0;
        tls_p->consumed = 0;
 
-       /* Ensure TLS handshake is performed before receiving data. */
-       while (tls_p->handshake_state == TLS_HS_IN_PROGRESS) {
-               int err = gnutls_handshake(tls_p->tls_session);
-               if (err == GNUTLS_E_SUCCESS) {
-                       tls_p->handshake_state = TLS_HS_DONE;
-                       kr_log_verbose("[%s] TLS handshake with %s has completed\n",
-                                      logstring,  kr_straddr(&session->peer.ip));
-                       if (tls_p->handshake_cb) {
-                               tls_p->handshake_cb(tls_p->session, 0);
-                       }
-               } else if (err == GNUTLS_E_AGAIN) {
-                       return 0;
-               } else if (gnutls_error_is_fatal(err)) {
-                       kr_log_verbose("[%s] gnutls_handshake failed: %s (%d)\n",
-                                    logstring,
-                                    gnutls_strerror_name(err), err);
-                       if (tls_p->handshake_cb) {
-                               tls_p->handshake_cb(tls_p->session, -1);
-                       }
-                       return kr_error(err);
+       /* Ensure TLS handshake is performed before receiving data.
+        * See https://www.gnutls.org/manual/html_node/TLS-handshake.html */
+       while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) {
+               int err = tls_handshake(tls_p, tls_p->handshake_cb);
+               if (err == kr_error(EAGAIN)) {
+                       return 0; /* Wait for more data */
+               } else if (err != kr_ok()) {
+                       return err;
                }
        }
 
+       /* See https://gnutls.org/manual/html_node/Data-transfer-and-termination.html#Data-transfer-and-termination */
        int submitted = 0;
-       bool is_retrying = false;
-       uint64_t retrying_start = 0;
        while (true) {
                ssize_t count = gnutls_record_recv(tls_p->tls_session, tls_p->recv_buf, sizeof(tls_p->recv_buf));
                if (count == GNUTLS_E_AGAIN) {
-                       break;    /* No data available */
-               } else if (count == GNUTLS_E_INTERRUPTED ||
-                          count == GNUTLS_E_REHANDSHAKE) {
-                       if (!is_retrying) {
-                               is_retrying = true;
-                               retrying_start = kr_now();
-                       }
-                       uint64_t elapsed = kr_now() - retrying_start;
-                       if (elapsed > TLS_MAX_HANDSHAKE_TIME) {
-                               return kr_error(EIO);
+                       break; /* No data available */
+               } else if (count == GNUTLS_E_INTERRUPTED) {
+                       continue;
+               } else if (count == GNUTLS_E_REHANDSHAKE) {
+                       /* See https://www.gnutls.org/manual/html_node/Re_002dauthentication.html */
+                       tls_set_hs_state(tls_p, TLS_HS_IN_PROGRESS);
+                       while (tls_p->handshake_state <= TLS_HS_IN_PROGRESS) {
+                               int err = tls_handshake(tls_p, tls_p->handshake_cb);
+                               if (err == kr_error(EAGAIN)) {
+                                       break;
+                               } else if (err != kr_ok()) {
+                                       return err;
+                               }
                        }
-                       continue; /* Try reading again */
+                       /* Wait for more data */
+                       break;
                } else if (count < 0) {
                        kr_log_verbose("[%s] gnutls_record_recv failed: %s (%zd)\n",
                                     logstring, gnutls_strerror_name(count), count);
@@ -334,7 +453,7 @@ int tls_process(struct worker_ctx *worker, uv_stream_t *handle, const uint8_t *b
                if (ret < 0) {
                        return ret;
                }
-               if (count == 0) {
+               if (count <= 0) {
                        break;
                }
                submitted += ret;
@@ -562,9 +681,9 @@ void tls_credentials_free(struct tls_credentials *tls_credentials) {
        free(tls_credentials);
 }
 
-static int client_paramlist_entry_clear(const char *k, void *v, void *baton)
+static int client_paramlist_entry_free(struct tls_client_paramlist_entry *entry)
 {
-       struct tls_client_paramlist_entry *entry = (struct tls_client_paramlist_entry *)v;
+       DEBUG_MSG("freeing TLS parameters %p\n", entry);
 
        while (entry->ca_files.len > 0) {
                if (entry->ca_files.at[0] != NULL) {
@@ -604,6 +723,75 @@ static int client_paramlist_entry_clear(const char *k, void *v, void *baton)
        return 0;
 }
 
+/* Take one reference on the parameter entry; a NULL entry is ignored. */
+static void client_paramlist_entry_ref(struct tls_client_paramlist_entry *entry)
+{
+       if (entry == NULL) {
+               return;
+       }
+       entry->refs++;
+}
+
+/* Drop one reference on the parameter entry; a NULL entry is ignored.
+ * When the count reaches zero the entry is handed to client_paramlist_entry_free(). */
+static void client_paramlist_entry_unref(struct tls_client_paramlist_entry *entry)
+{
+       if (entry == NULL) {
+               return;
+       }
+
+       assert(entry->refs > 0);
+       entry->refs--;
+
+       /* Last reference frees the object */
+       if (entry->refs != 0) {
+               return;
+       }
+       client_paramlist_entry_free(entry);
+}
+
+/* Map-walk style adapter (key, value, baton): frees the entry stored as the
+ * value and returns the result of client_paramlist_entry_free().
+ * The key and the baton are not used. */
+static int client_paramlist_entry_clear(const char *k, void *v, void *baton)
+{
+       return client_paramlist_entry_free((struct tls_client_paramlist_entry *)v);
+}
+
+/* Look up TLS client parameters for the port-853 counterpart of `addr`.
+ * Only addresses on KR_DNS_PORT are considered; returns NULL otherwise, when
+ * the address cannot be formatted, or when the map has no matching key.
+ * NOTE(review): the key lives in a static buffer, so the function is not
+ * reentrant; the in-place port rewrite also relies on the textual layout
+ * produced by kr_inaddr_str() - confirm against that helper. */
+struct tls_client_paramlist_entry *tls_client_try_upgrade(map_t *tls_client_paramlist,
+                         const struct sockaddr *addr)
+{
+       static char key[INET6_ADDRSTRLEN + 6];
+       size_t keylen = sizeof(key);
+
+       /* Opportunistic upgrade from port 53 -> 853 */
+       const bool is_dns_port = (kr_inaddr_port(addr) == KR_DNS_PORT);
+       if (!is_dns_port || kr_inaddr_str(addr, key, &keylen) != 0) {
+               return NULL;
+       }
+
+       /* Rewrite 053 -> 853 */
+       char *port_digits = key + keylen - 3;
+       strcpy(port_digits, "853");
+
+       return map_get(tls_client_paramlist, key);
+}
+
+/* Remove the TLS client parameters registered for addr/port, if any.
+ * The map's reference on the entry is dropped before the key is deleted.
+ * Returns kr_error(EINVAL) for NULL arguments or when the key cannot be
+ * built by kr_straddr_join(), kr_ok() otherwise (also when nothing was stored). */
+int tls_client_params_clear(map_t *tls_client_paramlist, const char *addr, uint16_t port)
+{
+       if (tls_client_paramlist == NULL || addr == NULL) {
+               return kr_error(EINVAL);
+       }
+
+       /* Build the lookup key from the address and the port. */
+       char key[INET6_ADDRSTRLEN + 6];
+       size_t keylen = sizeof(key);
+       if (kr_straddr_join(addr, port, key, &keylen) != kr_ok()) {
+               return kr_error(EINVAL);
+       }
+
+       struct tls_client_paramlist_entry *entry = map_get(tls_client_paramlist, key);
+       if (entry == NULL) {
+               return kr_ok();
+       }
+
+       client_paramlist_entry_unref(entry);
+       map_del(tls_client_paramlist, key);
+       return kr_ok();
+}
+
 int tls_client_params_set(map_t *tls_client_paramlist,
                          const char *addr, uint16_t port,
                          const char *param, tls_client_param_t param_type)
@@ -645,6 +833,7 @@ int tls_client_params_set(map_t *tls_client_paramlist,
                        return kr_error(ENOMEM);
                }
                gnutls_certificate_set_verify_function(entry->credentials, client_verify_certificate);
+               client_paramlist_entry_ref(entry);
        }
 
        int ret = kr_ok();
@@ -744,7 +933,7 @@ int tls_client_params_set(map_t *tls_client_paramlist,
        }
 
        if ((ret != kr_ok()) && is_first_entry) {
-               client_paramlist_entry_clear(NULL, (void *)entry, NULL);
+               client_paramlist_entry_unref(entry);
        }
 
        return ret;
@@ -879,7 +1068,7 @@ skip_pins:
        return GNUTLS_E_CERTIFICATE_ERROR;
 }
 
-struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_entry *entry,
+struct tls_client_ctx_t *tls_client_ctx_new(struct tls_client_paramlist_entry *entry,
                                            struct worker_ctx *worker)
 {
        struct tls_client_ctx_t *ctx = calloc(1, sizeof (struct tls_client_ctx_t));
@@ -887,7 +1076,7 @@ struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_en
                return NULL;
        }
 
-       int ret = gnutls_init(&ctx->c.tls_session, GNUTLS_CLIENT | GNUTLS_NONBLOCK);
+       int ret = gnutls_init(&ctx->c.tls_session, GNUTLS_CLIENT | GNUTLS_NONBLOCK | GNUTLS_ENABLE_FALSE_START);
        if (ret != GNUTLS_E_SUCCESS) {
                tls_client_ctx_free(ctx);
                return NULL;
@@ -899,6 +1088,11 @@ struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_en
                return NULL;
        }
 
+       /* Must take a reference on parameters as the credentials are owned by it
+        * and must not be freed while the session is active. */
+       client_paramlist_entry_ref(entry);
+       ctx->params = entry;
+
        ret = gnutls_credentials_set(ctx->c.tls_session, GNUTLS_CRD_CERTIFICATE,
                                     entry->credentials);
        if (ret != GNUTLS_E_SUCCESS) {
@@ -910,7 +1104,7 @@ struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_en
        ctx->c.client_side = true;
 
        gnutls_transport_set_pull_function(ctx->c.tls_session, kres_gnutls_pull);
-       gnutls_transport_set_push_function(ctx->c.tls_session, worker_gnutls_push);
+       gnutls_transport_set_vec_push_function(ctx->c.tls_session, kres_gnutls_vec_push);
        gnutls_transport_set_ptr(ctx->c.tls_session, ctx);
        return ctx;
 }
@@ -926,6 +1120,9 @@ void tls_client_ctx_free(struct tls_client_ctx_t *ctx)
                ctx->c.tls_session = NULL;
        }
 
+       /* Must decrease the refcount for referenced parameters */
+       client_paramlist_entry_unref(ctx->params);
+
        free (ctx);
 }
 
@@ -942,7 +1139,7 @@ int tls_client_connect_start(struct tls_client_ctx_t *client_ctx,
        struct tls_common_ctx *ctx = &client_ctx->c;
 
        gnutls_session_set_ptr(ctx->tls_session, client_ctx);
-       gnutls_handshake_set_timeout(ctx->tls_session, KR_CONN_RTT_MAX * 3);
+       gnutls_handshake_set_timeout(ctx->tls_session, ctx->worker->engine->net.tcp.tls_handshake_timeout);
        session->tls_client_ctx = client_ctx;
        ctx->handshake_cb = handshake_cb;
        ctx->handshake_state = TLS_HS_IN_PROGRESS;
@@ -954,14 +1151,15 @@ int tls_client_connect_start(struct tls_client_ctx_t *client_ctx,
                                        tls_params->session_data.size);
        }
 
-       int ret = gnutls_handshake(ctx->tls_session);
-       if (ret == GNUTLS_E_SUCCESS) {
-               return kr_ok();
-       } else if (gnutls_error_is_fatal(ret) != 0) {
-               kr_log_verbose("[tls_client] handshake failed (%s)\n", gnutls_strerror(ret));
-               return kr_error(ECONNABORTED);
+       /* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
+       while (ctx->handshake_state <= TLS_HS_IN_PROGRESS) {
+               /* Don't pass the handshake callback as the connection isn't registered yet. */
+               int ret = tls_handshake(ctx, NULL);
+               if (ret != kr_ok()) {
+                       return ret;
+               }
        }
-       return kr_error(EAGAIN);
+       return kr_ok();
 }
 
 tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx)
@@ -978,14 +1176,11 @@ int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state)
        return kr_ok();
 }
 
-int tls_client_ctx_set_params(struct tls_client_ctx_t *ctx,
-                             struct tls_client_paramlist_entry *entry,
-                             struct session *session)
+int tls_client_ctx_set_session(struct tls_client_ctx_t *ctx, struct session *session)
 {
        if (!ctx) {
                return kr_error(EINVAL);
        }
-       ctx->params = entry;
        ctx->c.session = session;
        return kr_ok();
 }
index c5c01c787520741a2dd8f0b278fe599aaad52081..ce13d20bb21799b52db0dfb4b66d5a896bcfbe4d 100644 (file)
@@ -42,6 +42,9 @@
  */
  #define TLS_MAX_HANDSHAKE_TIME (KR_CONN_RTT_MAX * 3)
 
+/** Transport session (opaque). */
+struct session;
+
 struct tls_ctx_t;
 struct tls_client_ctx_t;
 struct tls_credentials {
@@ -59,6 +62,7 @@ struct tls_client_paramlist_entry {
        array_t(const char *) pins;
        gnutls_certificate_credentials_t credentials;
        gnutls_datum_t session_data;
+       uint32_t refs;
 };
 
 struct worker_ctx;
@@ -93,7 +97,6 @@ struct tls_common_ctx {
        uint8_t recv_buf[4096];
        tls_handshake_cb handshake_cb;
        struct worker_ctx *worker;
-       struct qr_task *task;
 };
 
 struct tls_ctx_t {
@@ -126,7 +129,7 @@ void tls_close(struct tls_common_ctx *ctx);
 void tls_free(struct tls_ctx_t* tls);
 
 /*! Push new data to TLS context for sending */
-int tls_push(struct qr_task *task, uv_handle_t* handle, knot_pkt_t * pkt);
+int tls_write(uv_write_t *req, uv_handle_t* handle, knot_pkt_t * pkt, uv_write_cb cb);
 
 /*! Unwrap incoming data from a TLS stream and pass them to TCP session.
  * @return the number of newly-completed requests (>=0) or an error code
@@ -158,6 +161,15 @@ tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx);
 /*! Set TLS handshake state. */
 int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state);
 
+/*! Find TLS parameters for given address. Attempt opportunistic upgrade for port 53 to 853,
+ *  if the address is configured with a working DoT on port 853.
+ */
+struct tls_client_paramlist_entry *tls_client_try_upgrade(map_t *tls_client_paramlist,
+                         const struct sockaddr *addr);
+
+/*! Clear (remove) TLS parameters for given address. */
+int tls_client_params_clear(map_t *tls_client_paramlist, const char *addr, uint16_t port);
+
 /*! Set TLS authentication parameters for given address.
  * Note: hostnames must be imported before ca files,
  *       otherwise ca files will not be imported at all.
@@ -170,7 +182,7 @@ int tls_client_params_set(map_t *tls_client_paramlist,
 int tls_client_params_free(map_t *tls_client_paramlist);
 
 /*! Allocate new client TLS context */
-struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_entry *entry,
+struct tls_client_ctx_t *tls_client_ctx_new(struct tls_client_paramlist_entry *entry,
                                            struct worker_ctx *worker);
 
 /*! Free client TLS context */
@@ -180,9 +192,7 @@ int tls_client_connect_start(struct tls_client_ctx_t *client_ctx,
                             struct session *session,
                             tls_handshake_cb handshake_cb);
 
-int tls_client_ctx_set_params(struct tls_client_ctx_t *ctx,
-                             struct tls_client_paramlist_entry *entry,
-                             struct session *session);
+int tls_client_ctx_set_session(struct tls_client_ctx_t *ctx, struct session *session);
 
 
 /* Session tickets, server side.  Implementation in ./tls_session_ticket-srv.c */
index 850d02f560b7239e1c8a20090d092e7e46aa070d..5a6e2e58357a3018fafa68383ac75b99c053fc10 100644 (file)
@@ -217,7 +217,15 @@ static uv_handle_t *ioreq_spawn(struct qr_task *task, int socktype, sa_family_t
        if (!handle) {
                return NULL;
        }
-       io_create(worker->loop, handle, socktype);
+       int ret = io_create(worker->loop, handle, socktype, family);
+       if (ret) {
+               if (ret == UV_EMFILE) {
+                       worker->too_many_open = true;
+                       worker->rconcurrent_highwatermark = worker->stats.rconcurrent;
+               }
+               iohandle_release(worker, h);
+               return NULL;
+       }
 
        /* Bind to outgoing address, according to IP v4/v6. */
        union inaddr *addr;
@@ -226,7 +234,6 @@ static uv_handle_t *ioreq_spawn(struct qr_task *task, int socktype, sa_family_t
        } else {
                addr = (union inaddr *)&worker->out_addr6;
        }
-       int ret = 0;
        if (addr->ip.sa_family != AF_UNSPEC) {
                assert(addr->ip.sa_family == family);
                if (socktype == SOCK_DGRAM) {
@@ -900,86 +907,6 @@ static void on_task_write(uv_write_t *req, int status)
        iorequest_release(worker, req);
 }
 
-static void on_nontask_write(uv_write_t *req, int status)
-{
-       uv_handle_t *handle = (uv_handle_t *)(req->handle);
-       uv_loop_t *loop = handle->loop;
-       struct worker_ctx *worker = loop->data;
-       assert(worker == get_worker());
-       iorequest_release(worker, req);
-}
-
-ssize_t worker_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len)
-{
-       struct tls_common_ctx *t = (struct tls_common_ctx *)h;
-       const uv_buf_t uv_buf[1] = {
-               { (char *)buf, len }
-       };
-
-       if (t == NULL) {
-               errno = EFAULT;
-               return -1;
-       }
-
-       assert(t->session && t->session->handle &&
-              t->session->handle->type == UV_TCP);
-
-       VERBOSE_MSG(NULL,"[%s] push %zu <%p>\n",
-                   t->client_side ? "tls_client" : "tls", len, h);
-
-       struct worker_ctx *worker = t->worker;
-       assert(worker);
-
-       void *ioreq = worker_iohandle_borrow(worker);
-       if (!ioreq) {
-               errno = EFAULT;
-               return -1;
-       }
-
-       uv_write_t *write_req = (uv_write_t *)ioreq;
-
-       struct qr_task *task = t->task;
-       uv_write_cb write_cb = on_task_write;
-       if (t->handshake_state == TLS_HS_DONE) {
-               assert(task);
-       } else {
-               task = NULL;
-               write_cb = on_nontask_write;
-       }
-
-       write_req->data = task;
-
-       ssize_t ret = -1;
-       int res = uv_write(write_req, (uv_stream_t *)t->session->handle, uv_buf, 1, write_cb);
-       if (res == 0) {
-               if (task) {
-                       qr_task_ref(task); /* Pending ioreq on current task */
-                       struct request_ctx *ctx = task->ctx;
-                       if (ctx && ctx->source.session &&
-                           t->session->handle != ctx->source.session->handle) {
-                               struct sockaddr *addr = &t->session->peer.ip;
-                               worker->stats.tls += 1;
-                               if (addr->sa_family == AF_INET6)
-                                       worker->stats.ipv6 += 1;
-                               else if (addr->sa_family == AF_INET)
-                                       worker->stats.ipv4 += 1;
-                       }
-               }
-               if (worker->too_many_open &&
-                   worker->stats.rconcurrent <
-                       worker->rconcurrent_highwatermark - 10) {
-                       worker->too_many_open = false;
-               }
-               ret = len;
-       } else {
-               VERBOSE_MSG(NULL,"[%s] uv_write: %s\n",
-                           t->client_side ? "tls_client" : "tls", uv_strerror(res));
-               iorequest_release(worker, ioreq);
-               errno = EIO;
-       }
-       return ret;
-}
-
 static int qr_task_send(struct qr_task *task, uv_handle_t *handle,
                        struct sockaddr *addr, knot_pkt_t *pkt)
 {
@@ -987,21 +914,6 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle,
                return qr_task_on_send(task, handle, kr_error(EIO));
        }
 
-       /* Synchronous push to TLS context, bypassing event loop. */
-       struct session *session = handle->data;
-       assert(session->closing == false);
-       if (session->has_tls) {
-               struct kr_request *req = &task->ctx->req;
-               if (session->outgoing) {
-                       int ret = kr_resolve_checkout(req, NULL, addr,
-                                                     SOCK_STREAM, pkt);
-                       if (ret != kr_ok()) {
-                               return ret;
-                       }
-               }
-               return tls_push(task, handle, pkt);
-       }
-
        int ret = 0;
        struct request_ctx *ctx = task->ctx;
        struct worker_ctx *worker = ctx->worker;
@@ -1031,8 +943,18 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle,
                        return ret;
                }
        }
+
+       /* Pending ioreq on current task */
+       qr_task_ref(task);
+
        /* Send using given protocol */
-       if (handle->type == UV_UDP) {
+       struct session *session = handle->data;
+       assert(session->closing == false);
+       if (session->has_tls) {
+               uv_write_t *write_req = (uv_write_t *)ioreq;
+               write_req->data = task;
+               ret = tls_write(write_req, handle, pkt, &on_task_write);
+       } else if (handle->type == UV_UDP) {
                uv_udp_send_t *send_req = (uv_udp_send_t *)ioreq;
                uv_buf_t buf = { (char *)pkt->wire, pkt->size };
                send_req->data = task;
@@ -1051,7 +973,6 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle,
        }
 
        if (ret == 0) {
-               qr_task_ref(task); /* Pending ioreq on current task */
                if (worker->too_many_open &&
                    worker->stats.rconcurrent <
                        worker->rconcurrent_highwatermark - 10) {
@@ -1059,6 +980,7 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle,
                }
        } else {
                iorequest_release(worker, ioreq);
+               qr_task_unref(task);
                if (ret == UV_EMFILE) {
                        worker->too_many_open = true;
                        worker->rconcurrent_highwatermark = worker->stats.rconcurrent;
@@ -1069,15 +991,19 @@ static int qr_task_send(struct qr_task *task, uv_handle_t *handle,
        if (ctx->source.session &&
            handle != ctx->source.session->handle &&
            addr) {
-               if (handle->type == UV_UDP)
+               if (session->has_tls)
+                       worker->stats.tls += 1;
+               else if (handle->type == UV_UDP)
                        worker->stats.udp += 1;
                else
                        worker->stats.tcp += 1;
+
                if (addr->sa_family == AF_INET6)
                        worker->stats.ipv6 += 1;
                else if (addr->sa_family == AF_INET)
                        worker->stats.ipv4 += 1;
        }
+
        return ret;
 }
 
@@ -1544,12 +1470,16 @@ static int qr_task_finalize(struct qr_task *task, int state)
        }
        struct request_ctx *ctx = task->ctx;
        kr_resolve_finish(&ctx->req, state);
+
        task->finished = true;
        if (ctx->source.session == NULL) {
                (void) qr_task_on_send(task, NULL, kr_error(EIO));
                return state == KR_STATE_DONE ? 0 : kr_error(EIO);
        }
 
+       /* Reference task as the callback handler can close it */
+       qr_task_ref(task);
+
        /* Send back answer */
        struct session *source_session = ctx->source.session;
        uv_handle_t *handle = source_session->handle;
@@ -1573,12 +1503,14 @@ static int qr_task_finalize(struct qr_task *task, int state)
                        session_del_tasks(source_session, t);
                }
                session_close(source_session);
-       } else if (handle->type == UV_TCP) {
+       } else if (handle->type == UV_TCP && ctx->source.session) {
                /* Don't try to close source session at least
                 * retry_interval_for_timeout_timer milliseconds */
                uv_timer_again(&ctx->source.session->timeout);
        }
 
+       qr_task_unref(task);
+
        return state == KR_STATE_DONE ? 0 : kr_error(EIO);
 }
 
@@ -1820,7 +1752,7 @@ static int qr_task_step(struct qr_task *task,
                                        subreq_finalize(task, packet_source, packet);
                                        return qr_task_step(task, NULL, NULL);
                                }
-                               tls_client_ctx_set_params(tls_ctx, entry, session);
+                               tls_client_ctx_set_session(tls_ctx, session);
                                session->tls_client_ctx = tls_ctx;
                                session->has_tls = true;
                        }
index 0d1bcc20825ff82ef65f2ab06c803f5e5170208e..3acecfd0eab6721f1f5fcc5a8b1cfcda01f1c00f 100644 (file)
@@ -16,8 +16,6 @@
 
 #pragma once
 
-#include <gnutls/gnutls.h>
-
 #include "daemon/engine.h"
 #include "lib/generic/array.h"
 #include "lib/generic/map.h"
@@ -92,10 +90,6 @@ void *worker_iohandle_borrow(struct worker_ctx *worker);
 
 void worker_iohandle_release(struct worker_ctx *worker, void *h);
 
-ssize_t worker_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len);
-
-ssize_t worker_gnutls_client_push(gnutls_transport_ptr_t h, const void *buf, size_t len);
-
 /** Finalize given task */
 int worker_task_finalize(struct qr_task *task, int state);
 
index 38ba39d55ab7dce00974e498f8d9a7817c3b3060..678cec22b253c69d1c119f7349210f1c5dd110a5 100644 (file)
@@ -410,6 +410,18 @@ uint16_t kr_inaddr_port(const struct sockaddr *addr)
        }
 }
 
+void kr_inaddr_set_port(struct sockaddr *addr, uint16_t port) /* port: host byte order */
+{
+       if (!addr) { /* NULL address: nothing to update */
+               return;
+       }
+       switch (addr->sa_family) { /* each case must break: no fallthrough into the other family */
+       case AF_INET:  ((struct sockaddr_in *)addr)->sin_port = htons(port); break;
+       case AF_INET6: ((struct sockaddr_in6 *)addr)->sin6_port = htons(port); break;
+       default: break; /* any other address family is left untouched */
+       }
+}
+
 int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
 {
        int ret = kr_ok();
index d5806c284793f1faae0f0236454fa3523f2ccc8c..1e6cb16ff719651e0c3481ad4253f2de856673ec 100644 (file)
@@ -247,6 +247,9 @@ int kr_sockaddr_cmp(const struct sockaddr *left, const struct sockaddr *right);
 /** Port. */
 KR_EXPORT KR_PURE
 uint16_t kr_inaddr_port(const struct sockaddr *addr);
+/** Set port. */
+KR_EXPORT
+void kr_inaddr_set_port(struct sockaddr *addr, uint16_t port);
 /** String representation for given address as "<addr>#<port>" */
 KR_EXPORT
 int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen);