]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: optimizations and logic fixes
authorOto Šťáva <oto.stava@nic.cz>
Thu, 9 Feb 2023 07:27:11 +0000 (08:27 +0100)
committerOto Šťáva <oto.stava@nic.cz>
Thu, 9 Feb 2023 07:27:11 +0000 (08:27 +0100)
daemon/session2.c
daemon/session2.h
daemon/tls.c
daemon/worker.c

index c9e9e5945d710ffc3b6c4d38649fc749fc280ed9..b7925d920e9aa6733918c99a5cd15453f9e966a3 100644 (file)
@@ -218,6 +218,27 @@ size_t protolayer_queue_count_payload(const protolayer_iter_ctx_queue_t *queue)
        return sum;
 }
 
+bool protolayer_queue_has_payload(const protolayer_iter_ctx_queue_t *queue)
+{
+       if (!queue || queue_len(*queue) == 0)
+               return false;
+
+       /* We're only reading from the queue, but we need to discard the
+        * `const` so that `queue_it_begin()` accepts it. As long as
+        * `queue_it_` operations do not write into the queue (which they do
+        * not, checked at the time of writing), we should be safely in the
+        * defined behavior territory. */
+       queue_it_t(struct protolayer_iter_ctx *) it =
+               queue_it_begin(*(protolayer_iter_ctx_queue_t *)queue);
+       for (; !queue_it_finished(it); queue_it_next(it)) {
+               struct protolayer_iter_ctx *ctx = queue_it_val(it);
+               if (protolayer_payload_size(&ctx->payload))
+                       return true;
+       }
+
+       return false;
+}
+
 
 /** Gets layer-specific session data for the layer with the specified index
  * from the manager. */
index d3c94544c89d3326e4cd98badaf0f110c89067e1..213a7d93aea06ae3aa8c87ad420cf2a405c50bb8 100644 (file)
@@ -412,6 +412,11 @@ typedef queue_t(struct protolayer_iter_ctx *) protolayer_iter_ctx_queue_t;
  * available in it. */
 size_t protolayer_queue_count_payload(const protolayer_iter_ctx_queue_t *queue);
 
+/** Checks if the specified `queue` has any payload data (i.e.
+ * `protolayer_queue_count_payload` would be non-zero). This optimizes calls to
+ * queue iterators, as it does not need to iterate through the whole queue. */
+bool protolayer_queue_has_payload(const protolayer_iter_ctx_queue_t *queue);
+
 /** Mandatory header members for any layer-specific data. */
 #define PROTOLAYER_DATA_HEADER() struct {\
        struct session2 *session; /**< Pointer to the owner session. */\
index 56bcad504be4d8ffb0945f46ce5b545d917aab18..97dd3cb1d825bc754fc46fde2e186d8d07119dcf 100644 (file)
@@ -104,11 +104,11 @@ static ssize_t kres_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
                return -1;
        }
 
-       size_t avail = protolayer_queue_count_payload(&tls->unwrap_queue);
-       DEBUG_MSG("[%s] pull wanted: %zu avail: %zu\n",
+       bool avail = protolayer_queue_has_payload(&tls->unwrap_queue);
+       DEBUG_MSG("[%s] pull wanted: %zu avail: %s\n",
                        tls->client_side ? "tls_client" : "tls",
-                       len, avail);
-       if (avail == 0) {
+                       len, avail ? "yes" : "no");
+       if (!avail) {
                errno = EAGAIN;
                return -1;
        }
@@ -1088,7 +1088,7 @@ static enum protolayer_iter_cb_result pl_tls_unwrap(void *sess_data, void *iter_
                                wire_buf_free_space(&tls->unwrap_buf),
                                wire_buf_free_space_length(&tls->unwrap_buf));
                if (count == GNUTLS_E_AGAIN) {
-                       if (protolayer_queue_count_payload(&tls->unwrap_queue) == 0) {
+                       if (!protolayer_queue_has_payload(&tls->unwrap_queue)) {
                                /* See https://www.gnutls.org/manual/html_node/Asynchronous-operation.html */
                                break;
                        }
@@ -1146,7 +1146,7 @@ static enum protolayer_iter_cb_result pl_tls_unwrap(void *sess_data, void *iter_
        }
 
        /* Here all data must be consumed. */
-       while (protolayer_queue_count_payload(&tls->unwrap_queue) > 0) {
+       if (protolayer_queue_has_payload(&tls->unwrap_queue)) {
                /* Something went wrong, better return error.
                 * This is most probably due to gnutls_record_recv() did not
                 * consume all available network data by calling kres_gnutls_pull().
@@ -1200,8 +1200,9 @@ static ssize_t pl_tls_submit(gnutls_session_t tls_session,
        return kr_error(EINVAL);
 }
 
-static enum protolayer_iter_cb_result pl_tls_wrap(void *sess_data, void *iter_data,
-                                             struct protolayer_iter_ctx *ctx)
+static enum protolayer_iter_cb_result pl_tls_wrap(
+               void *sess_data, void *iter_data,
+               struct protolayer_iter_ctx *ctx)
 {
        struct pl_tls_sess_data *tls = sess_data;
        gnutls_session_t tls_session = tls->tls_session;
index eb301cd0b64536460fb5b458ae90fd51f94d7555..6475b8c0fe22d108b034f61f74b15f6ee2d07098 100644 (file)
@@ -2041,16 +2041,36 @@ static enum protolayer_event_cb_result pl_dns_stream_event_unwrap(
        return PROTOLAYER_EVENT_PROPAGATE;
 }
 
-static knot_pkt_t *produce_stream_packet(struct wire_buf *wb)
+static knot_pkt_t *produce_stream_packet(struct session2 *session,
+                                         struct wire_buf *wb,
+                                         bool *out_err)
 {
+       *out_err = false;
+       if (wire_buf_data_length(wb) == 0) {
+               wire_buf_reset(wb);
+               return NULL;
+       }
+       if (wire_buf_data_length(wb) < sizeof(uint16_t)) {
+               return NULL;
+       }
+
        uint16_t pkt_len = knot_wire_read_u16(wire_buf_data(wb));
+       if (pkt_len == 0) {
+               *out_err = true;
+               return NULL;
+       }
+       if (pkt_len >= wb->size) {
+               *out_err = true;
+               return NULL;
+       }
        if (wire_buf_data_length(wb) < pkt_len + sizeof(uint16_t)) {
-               wire_buf_reset(wb);
                return NULL;
        }
 
+       session->was_useful = true;
        wire_buf_trim(wb, sizeof(uint16_t));
        knot_pkt_t *pkt = produce_packet(wire_buf_data(wb), pkt_len);
+       *out_err = (pkt == NULL);
        wire_buf_trim(wb, pkt_len);
        return pkt;
 }
@@ -2075,8 +2095,13 @@ static enum protolayer_iter_cb_result pl_dns_stream_unwrap(
                (KNOT_WIRE_HEADER_SIZE + KNOT_WIRE_QUESTION_MIN_SIZE)) + 1;
        int iters = 0;
 
-       knot_pkt_t *pkt;
-       while ((pkt = produce_stream_packet(wb)) && iters < max_iters) {
+       bool pkt_error = false;
+       knot_pkt_t *pkt = NULL;
+       while ((pkt = produce_stream_packet(session, wb, &pkt_error)) && iters < max_iters) {
+               if (kr_fails_assert(!pkt_error)) {
+                       status = kr_error(EINVAL);
+                       goto exit;
+               }
                if (stream_sess->single && stream_sess->produced) {
                        if (kr_log_is_debug(WORKER, NULL)) {
                                kr_log_debug(WORKER, "Unexpected extra data from %s\n",
@@ -2087,11 +2112,12 @@ static enum protolayer_iter_cb_result pl_dns_stream_unwrap(
                }
 
                stream_sess->produced = true;
-               if (pkt)
-                       session->was_useful = true;
 
                int ret = worker_submit(session, &ctx->comm, pkt);
                wire_buf_movestart(wb);
+
+               /* Errors from worker_submit() are intentionally *not* handled
+                * in order to ensure the entire wire buffer is processed. */
                if (ret == kr_ok()) {
                        iters += 1;
                }
@@ -2100,7 +2126,7 @@ static enum protolayer_iter_cb_result pl_dns_stream_unwrap(
        /* worker_submit() may cause the session to close (e.g. due to IO
         * write error when the packet triggers an immediate answer). This is
         * an error state, as well as any wirebuf error. */
-       if (session->closing)
+       if (session->closing || pkt_error)
                status = kr_error(EIO);
 
 exit: