git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: bugfixes & improvements
author: Grigorii Demidov <grigorii.demidov@nic.cz>
Wed, 3 Oct 2018 12:48:23 +0000 (14:48 +0200)
committer: Vladimír Čunát <vladimir.cunat@nic.cz>
Fri, 12 Oct 2018 15:36:46 +0000 (17:36 +0200)
daemon/io.c
daemon/session.c
daemon/session.h
daemon/tls.c
daemon/tls.h
daemon/worker.c

index 8d6f1532622e780e735fe1286605b892e33b283f..496d43a9d7f9cadf741c7b3ed842161aa712f42e 100644 (file)
@@ -195,7 +195,7 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
                consumed = tls_process_input_data(s, (const uint8_t *)buf->base, nread);
                data = session_wirebuf_get_free_start(s);
                data_len = consumed;
-       } 
+       }
 
        /* data points to start of the free space in session wire buffer.
           Simple increase internal counter. */
index 1dbe6f670f046ab5075adeafe647da8847088375..0ae93c62035070876a349d6b62654c077a7624ba 100644 (file)
@@ -108,6 +108,11 @@ int session_start_read(struct session *session)
        return io_start_read(session->handle);
 }
 
+int session_stop_read(struct session *session)
+{
+       return io_stop_read(session->handle);
+}
+
 int session_waitinglist_push(struct session *session, struct qr_task *task)
 {
        queue_push(session->waiting, task);
@@ -142,7 +147,7 @@ int session_tasklist_add(struct session *session, struct qr_task *task)
                key = (const char *)&task_msg_id;
                key_len = sizeof(task_msg_id);
        } else {
-               key = (const char *)task;
+               key = (const char *)&task;
                key_len = sizeof(task);
        }
        trie_val_t *v = trie_get_ins(t, key, key_len);
@@ -173,7 +178,7 @@ int session_tasklist_del(struct session *session, struct qr_task *task)
                key = (const char *)&task_msg_id;
                key_len = sizeof(task_msg_id);
        } else {
-               key = (const char *)task;
+               key = (const char *)&task;
                key_len = sizeof(task);
        }
        int ret = trie_del(t, key, key_len, &val);
index 182cabec3646d40d99e358c7817d57e7d55b8564..c0f68039a6e93ecd0c1f4a4269ef9c6e1bb239fb 100644 (file)
@@ -42,6 +42,8 @@ void session_clear(struct session *session);
 void session_close(struct session *session);
 /** Start reading from underlying libuv IO handle. */
 int session_start_read(struct session *session);
+/** Stop reading from underlying libuv IO handle. */
+int session_stop_read(struct session *session);
 
 /** List of tasks been waiting for IO. */
 /** Check if list is empty. */
index f73e7d74b51cb058e80c9abe963e3295ada9d01f..3ffbb1595d954aadefc0b320d663efda52e2118c 100644 (file)
 #define DEBUG_MSG(fmt...)
 #endif
 
+struct async_write_ctx {
+       uv_write_t write_req;
+       struct tls_common_ctx *t;
+       char buf[0];
+};
+
 static char const server_logstring[] = "tls";
 static char const client_logstring[] = "tls_client";
 
@@ -94,18 +100,16 @@ static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
 static void on_write_complete(uv_write_t *req, int status)
 {
        assert(req->data != NULL);
+       struct async_write_ctx *async_ctx = (struct async_write_ctx *)req->data;
+       struct tls_common_ctx *t = async_ctx->t;
+       assert(t->write_queue_size);
+       t->write_queue_size -= 1;
        free(req->data);
-       free(req);
 }
 
-static bool stream_queue_is_empty(uv_stream_t *handle)
+static bool stream_queue_is_empty(struct tls_common_ctx *t)
 {
-#if UV_VERSION_HEX >= 0x011900
-       return uv_stream_get_write_queue_size(handle) == 0;
-#else
-       /* Assume best case */
-       return true;
-#endif
+       return (t->write_queue_size == 0);
 }
 
 static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * iov, int iovcnt)
@@ -144,7 +148,7 @@ static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * i
 
        /* Try to perform the immediate write first to avoid copy */
        int ret = 0;
-       if (stream_queue_is_empty(handle)) {
+       if (stream_queue_is_empty(t)) {
                ret = uv_try_write(handle, uv_buf, iovcnt);
                DEBUG_MSG("[%s] push %zu <%p> = %d\n",
                    t->client_side ? "tls_client" : "tls", total_len, h, ret);
@@ -153,12 +157,19 @@ static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * i
                     > 0: number of bytes written (can be less than the supplied buffer size).
                     < 0: negative error code (UV_EAGAIN is returned if no data can be sent immediately).
                */
-               if ((ret == total_len) || (ret < 0 && ret != UV_EAGAIN)) {
-                       /* Either all the data were buffered by libuv or
-                        * uv_try_write() has returned error code other then UV_EAGAIN.
+               if (ret == total_len) {
+                       /* All the data were buffered by libuv.
                         * Return. */
                        return ret;
                }
+
+               if (ret < 0 && ret != UV_EAGAIN) {
+                       /* uv_try_write() has returned error code other then UV_EAGAIN.
+                        * Return. */
+                       ret = -1;
+                       errno = EIO;
+                       return ret;
+               }
                /* Since we are here expression below is true
                 * (ret != total_len) && (ret >= 0 || ret == UV_EAGAIN)
                 * or the same
@@ -173,10 +184,14 @@ static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * i
        }
 
        /* Fallback when the queue is full, and it's not possible to do an immediate write */
-       char *buf = malloc(total_len - ret);
-       if (buf != NULL) {
+       char *p = malloc(sizeof(struct async_write_ctx) + total_len - ret);
+       if (p != NULL) {
+               struct async_write_ctx *async_ctx = (struct async_write_ctx *)p;
+               /* Save pointer to session tls context */
+               async_ctx->t = t;
+               char *buf = async_ctx->buf;
                /* Skip data written in the partial write */
-               int to_skip = ret;
+               size_t to_skip = ret;
                /* Copy the buffer into owned memory */
                size_t off = 0;
                for (int i = 0; i < iovcnt; ++i) {
@@ -198,21 +213,16 @@ static ssize_t kres_gnutls_vec_push(gnutls_transport_ptr_t h, const giovec_t * i
                uv_buf[0].len = off;
 
                /* Create an asynchronous write request */
-               uv_write_t *write_req = calloc(1, sizeof(uv_write_t));
-               if (write_req != NULL) {
-                       write_req->data = buf;
-               } else {
-                       free(buf);
-                       errno = ENOMEM;
-                       return -1;
-               }
+               uv_write_t *write_req = &async_ctx->write_req;
+               memset(write_req, 0, sizeof(uv_write_t));
+               write_req->data = p;
 
                /* Perform an asynchronous write with a callback */
                if (uv_write(write_req, handle, uv_buf, 1, on_write_complete) == 0) {
                        ret = total_len;
+                       t->write_queue_size += 1;
                } else {
-                       free(buf);
-                       free(write_req);
+                       free(p);
                        errno = EIO;
                        ret = -1;
                }
@@ -410,10 +420,14 @@ int tls_write(uv_write_t *req, uv_handle_t *handle, knot_pkt_t *pkt, uv_write_cb
        const ssize_t submitted = sizeof(pkt_size) + pkt->size;
 
        int ret = gnutls_record_uncork(tls_session, GNUTLS_RECORD_WAIT);
-       if (gnutls_error_is_fatal(ret)) {
-               kr_log_error("[%s] gnutls_record_uncork failed: %s (%d)\n",
-                            logstring, gnutls_strerror_name(ret), ret);
-               return kr_error(EIO);
+       if (ret < 0) {
+               if (!gnutls_error_is_fatal(ret)) {
+                       return kr_error(EAGAIN);
+               } else {
+                       kr_log_error("[%s] gnutls_record_uncork failed: %s (%d)\n",
+                                    logstring, gnutls_strerror_name(ret), ret);
+                       return kr_error(EIO);
+               }
        }
 
        if (ret != submitted) {
index 1bfa6ef6de512fd3de250e827b46bf45f4f5b18a..cb3d4a64f1e79b3f7f42254d3416a7991c54ed4f 100644 (file)
@@ -94,9 +94,10 @@ struct tls_common_ctx {
        const uint8_t *buf;
        ssize_t nread;
        ssize_t consumed;
-       uint8_t recv_buf[4096];
+       uint8_t recv_buf[8192];
        tls_handshake_cb handshake_cb;
        struct worker_ctx *worker;
+       size_t write_queue_size;
 };
 
 struct tls_ctx_t {
index c2d2fc3b991aee00465c5c03ed3399108eb9685b..ea74dc0a2975794723abc7b136ef23b5c7ddf446 100644 (file)
@@ -488,7 +488,7 @@ static void qr_task_free(struct qr_task *task)
 
        /* Process source session. */
        if (s && session_tasklist_get_len(s) < worker->tcp_pipeline_max/2 &&
-           !session_flags(s)->closing && !session_flags(s)->throttled) {
+           !session_flags(s)->closing && session_flags(s)->throttled) {
                uv_handle_t *handle = session_get_handle(s);
                /* Start reading again if the session is throttled and
                 * the number of outgoing requests is below watermark. */
@@ -522,12 +522,10 @@ static int qr_task_register(struct qr_task *task, struct session *session)
         * an in effect shrink TCP window size. To get more precise throttling,
         * we would need to copy remainder of the unread buffer and reassemble
         * when resuming reading. This is NYI.  */
-       if (session_tasklist_get_len(session) >= task->ctx->worker->tcp_pipeline_max) {
-               uv_handle_t *handle = session_get_handle(session);
-               if (handle && !session_flags(session)->throttled && !session_flags(session)->closing) {
-                       io_stop_read(handle);
-                       session_flags(session)->throttled = true;
-               }
+       if (session_tasklist_get_len(session) >= task->ctx->worker->tcp_pipeline_max &&
+           !session_flags(session)->throttled && !session_flags(session)->closing) {
+               session_stop_read(session);
+               session_flags(session)->throttled = true;
        }
 
        return 0;
@@ -555,32 +553,35 @@ static void qr_task_complete(struct qr_task *task)
 /* This is called when we send subrequest / answer */
 static int qr_task_on_send(struct qr_task *task, uv_handle_t *handle, int status)
 {
+
        if (task->finished) {
                assert(task->leading == false);
                qr_task_complete(task);
-               if (!handle || handle->type != UV_TCP) {
-                       return status;
-               }
-               struct session* s = handle->data;
-               assert(s);
-               if (!session_flags(s)->outgoing || session_waitinglist_is_empty(s)) {
-                       return status;
-               }
        }
 
-       if (handle) {
-               struct session* s = handle->data;
-               bool outgoing = session_flags(s)->outgoing;
-               if (!outgoing) {
-                       struct session* source_s = task->ctx->source.session;
-                       if (source_s) {
-                               assert (session_get_handle(source_s) == handle);
-                       }
-               }
-               if (!session_flags(s)->closing) {
-                       io_start_read(handle); /* Start reading new query */
-               }
+       if (!handle || handle->type != UV_TCP) {
+               return status;
+       }
+
+       struct session* s = handle->data;
+       assert(s);
+       if (status != 0) {
+               session_tasklist_del(s, task);
+       }
+
+       if (session_flags(s)->outgoing || session_flags(s)->closing) {
+               return status;
+       }
+
+       struct worker_ctx *worker = task->ctx->worker;
+       if (session_flags(s)->throttled &&
+           session_tasklist_get_len(s) < worker->tcp_pipeline_max/2) {
+          /* Start reading again if the session is throttled and
+           * the number of outgoing requests is below watermark. */
+               session_start_read(s);
+               session_flags(s)->throttled = false;
        }
+
        return status;
 }
 
@@ -629,14 +630,14 @@ static int qr_task_send(struct qr_task *task, struct session *session,
        if (session_flags(session)->outgoing) {
                size_t try_limit = session_tasklist_get_len(session) + 1;
                uint16_t msg_id = knot_wire_get_id(pkt->wire);
-               int try_count = 0;
+               size_t try_count = 0;
                while (session_tasklist_find_msgid(session, msg_id) &&
                       try_count <= try_limit) {
                        ++msg_id;
                        ++try_count;
                }
                if (try_count > try_limit) {
-                       return qr_task_on_send(task, handle, kr_error(EIO));
+                       return kr_error(ENOENT);
                }
                worker_task_pkt_set_msgid(task, msg_id);
        }
@@ -867,13 +868,13 @@ static void on_connect(uv_connect_t *req, int status)
        }
 
        session_flags(session)->connected = true;
+       session_start_read(session);
 
        int ret = kr_ok();
        if (session_flags(session)->has_tls) {
                struct tls_client_ctx_t *tls_ctx = session_tls_get_client_ctx(session);
                ret = tls_client_connect_start(tls_ctx, session, session_tls_hs_cb);
                if (ret == kr_error(EAGAIN)) {
-                       session_start_read(session);
                        session_timer_start(session, on_tcp_watchdog_timeout,
                                            MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
                        return;
@@ -886,7 +887,6 @@ static void on_connect(uv_connect_t *req, int status)
                ret = qr_task_send(t, session, NULL, NULL);
                if (ret != 0) {
                        assert(session_tasklist_is_empty(session));
-                       assert(false);
                        worker_del_tcp_connected(worker, peer);
                        session_waitinglist_finalize(session, KR_STATE_FAIL);
                        session_close(session);
@@ -894,6 +894,8 @@ static void on_connect(uv_connect_t *req, int status)
                }
                session_waitinglist_pop(session, true);
        }
+       session_timer_start(session, on_tcp_watchdog_timeout,
+                           MAX_TCP_INACTIVITY, MAX_TCP_INACTIVITY);
 }
 
 static void on_tcp_connect_timeout(uv_timer_t *timer)
@@ -1012,6 +1014,7 @@ static uv_handle_t *retransmit(struct qr_task *task)
                        task->pending_count += 1;
                        task->addrlist_turn = (task->addrlist_turn + 1) %
                                              task->addrlist_count; /* Round robin */
+                       session_start_read(session); /* Start reading answer */
                }
        }
        return ret;
@@ -1180,9 +1183,11 @@ static int qr_task_step(struct qr_task *task,
                if (worker->stats.rconcurrent <
                        worker->rconcurrent_highwatermark - 10) {
                        worker->too_many_open = false;
-               } else if (packet && kr_rplan_empty(rplan)) {
-                       /* new query; TODO - make this detection more obvious */
-                       kr_resolve_consume(req, packet_source, packet);
+               } else {
+                       if (packet && kr_rplan_empty(rplan)) {
+                               /* new query; TODO - make this detection more obvious */
+                               kr_resolve_consume(req, packet_source, packet);
+                       }
                        return qr_task_finalize(task, KR_STATE_FAIL);
                }
        }
@@ -1294,6 +1299,11 @@ static int qr_task_step(struct qr_task *task,
                                session_waitinglist_pop(session, true);
                        }
 
+                       if (session_tasklist_get_len(session) >= worker->tcp_pipeline_max) {
+                               subreq_finalize(task, packet_source, packet);
+                               return qr_task_finalize(task, KR_STATE_FAIL);
+                       }
+
                        ret = qr_task_send(task, session, NULL, NULL);
                        if (ret != 0 /* && ret != kr_error(EMFILE) */) {
                                session_tasklist_finalize(session, KR_STATE_FAIL);
@@ -1753,6 +1763,8 @@ void worker_task_pkt_set_msgid(struct qr_task *task, uint16_t msgid)
 {
        knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
        knot_wire_set_id(pktbuf->wire, msgid);
+       struct kr_query *q = task_get_last_pending_query(task);
+       q->id = msgid;
 }
 
 /** Reserve worker buffers */