]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: fix DoH not sending entire messages + nits
authorOto Šťáva <oto.stava@nic.cz>
Tue, 27 Sep 2022 08:06:01 +0000 (10:06 +0200)
committerOto Šťáva <oto.stava@nic.cz>
Thu, 26 Jan 2023 11:56:08 +0000 (12:56 +0100)
daemon/http.c
daemon/session2.c
daemon/session2.h
daemon/tls.c

index 27c19bfd036615a7365a2153fa687ca6247ab61c..2cf50d2e536dce3fceb88c96d52132bcee40c6ba 100644 (file)
@@ -233,6 +233,7 @@ static void http_cleanup_stream(struct pl_http_sess_data *ctx)
 {
        ctx->incomplete_stream = -1;
        ctx->current_method = HTTP_METHOD_NONE;
+       ctx->status = HTTP_STATUS_OK;
        free(ctx->uri_path);
        ctx->uri_path = NULL;
        http_free_headers(ctx->headers);
@@ -294,6 +295,8 @@ static int http_send_response(struct pl_http_sess_data *http, int32_t stream_id,
                max_age_len = asprintf(&max_age, "%s%" PRIu32, directive_max_age, ctx->payload.ttl);
                kr_require(max_age_len >= 0);
 
+               /* TODO: add a per-group option for content-type if we need to
+                * support protocols other than DNS here */
                push_nv(&hdrs, MAKE_STATIC_NV("content-type", "application/dns-message"));
                push_nv(&hdrs, MAKE_STATIC_KEY_NV("content-length", size, size_len));
                push_nv(&hdrs, MAKE_STATIC_KEY_NV("cache-control", max_age, max_age_len));
@@ -386,32 +389,60 @@ static int send_data_callback(nghttp2_session *h2, nghttp2_frame *frame, const u
 {
        struct pl_http_sess_data *http = user_data;
 
-/* I'm not yet sure if the below code is correct... the other one should be,
- * but it's probably considerably slower. */
-#if 1
        int has_padding = !!(frame->data.padlen);
        uint8_t padlen = (frame->data.padlen > 1) ? frame->data.padlen : 2;
 
        struct protolayer_iter_ctx *ctx = source->ptr;
-       struct protolayer_payload pld = ctx->payload;
+       struct protolayer_payload *pld = &ctx->payload;
 
        struct iovec bufiov;
        struct iovec *dataiov;
        int dataiovcnt;
-       if (pld.type == PROTOLAYER_PAYLOAD_BUFFER) {
-               bufiov = (struct iovec){ pld.buffer.buf, pld.buffer.len };
+       bool adapt_iovs = false;
+       if (pld->type == PROTOLAYER_PAYLOAD_BUFFER) {
+               size_t to_copy = MIN(length, pld->buffer.len);
+               if (!to_copy)
+                       return NGHTTP2_ERR_PAUSE;
+
+               bufiov = (struct iovec){ pld->buffer.buf, to_copy };
                dataiov = &bufiov;
                dataiovcnt = 1;
-       } else if (pld.type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
+
+               pld->buffer.buf = (char *)pld->buffer.buf + to_copy;
+               pld->buffer.len -= to_copy;
+       } else if (pld->type == PROTOLAYER_PAYLOAD_WIRE_BUF) {
+               size_t wbl = wire_buf_data_length(pld->wire_buf);
+               size_t to_copy = MIN(length, wbl);
+               if (!to_copy)
+                       return NGHTTP2_ERR_PAUSE;
+
                bufiov = (struct iovec){
-                       wire_buf_data(pld.wire_buf),
-                       wire_buf_data_length(pld.wire_buf)
+                       wire_buf_data(pld->wire_buf),
+                       to_copy
                };
                dataiov = &bufiov;
                dataiovcnt = 1;
-       } else if (pld.type == PROTOLAYER_PAYLOAD_IOVEC) {
-               dataiov = pld.iovec.iov;
-               dataiovcnt = pld.iovec.cnt;
+
+               wire_buf_trim(pld->wire_buf, to_copy);
+               if (wire_buf_data_length(pld->wire_buf) == 0) {
+                       wire_buf_reset(pld->wire_buf);
+               }
+       } else if (pld->type == PROTOLAYER_PAYLOAD_IOVEC) {
+               if (pld->iovec.cnt <= 0)
+                       return NGHTTP2_ERR_PAUSE;
+
+               dataiov = pld->iovec.iov;
+               dataiovcnt = 0;
+               size_t avail = 0;
+               for (int i = 0; i < pld->iovec.cnt && avail < length; i++) {
+                       avail += pld->iovec.iov[i].iov_len;
+                       dataiovcnt += 1;
+               }
+
+               /* The actual iovec generation needs to be done later when we
+                * have memory for them. Here, we just count the number of
+                * needed iovecs. */
+               adapt_iovs = true;
        } else {
                kr_assert(false && "Invalid payload");
                protolayer_break(ctx, kr_error(EINVAL));
@@ -422,69 +453,48 @@ static int send_data_callback(nghttp2_session *h2, nghttp2_frame *frame, const u
        struct http_send_data_ctx *sdctx = calloc(iovcnt, sizeof(*ctx) + sizeof(struct iovec[iovcnt]));
        sdctx->padlen = padlen;
 
-       struct iovec *iov = sdctx->iov;
+       struct iovec *dest_iov = sdctx->iov;
        static const uint8_t padding[UINT8_MAX];
 
        int cur = 0;
-       iov[cur++] = (struct iovec){ (void *)framehd, HTTP_FRAME_HDLEN };
+       dest_iov[cur++] = (struct iovec){ (void *)framehd, HTTP_FRAME_HDLEN };
 
        if (has_padding)
-               iov[cur++] = (struct iovec){ &sdctx->padlen, HTTP_FRAME_PADLEN };
-
-       memcpy(&iov[cur], dataiov, sizeof(struct iovec[dataiovcnt]));
-       cur += dataiovcnt;
+               dest_iov[cur++] = (struct iovec){ &sdctx->padlen, HTTP_FRAME_PADLEN };
+
+       if (adapt_iovs) {
+               while (pld->iovec.cnt && length > 0) {
+                       struct iovec *iov = pld->iovec.iov;
+                       size_t to_copy = MIN(length, iov->iov_len);
+
+                       dest_iov[cur++] = (struct iovec){
+                               iov->iov_base, to_copy
+                       };
+                       length -= to_copy;
+                       iov->iov_base = ((char *)iov->iov_base) + to_copy;
+                       iov->iov_len -= to_copy;
+
+                       if (iov->iov_len == 0) {
+                               pld->iovec.iov++;
+                               pld->iovec.cnt--;
+                       }
+               }
+       } else {
+               memcpy(&dest_iov[cur], dataiov, sizeof(struct iovec[dataiovcnt]));
+               cur += dataiovcnt;
+       }
 
        if (has_padding)
-               iov[cur++] = (struct iovec){ (void *)padding, padlen - 1 };
+               dest_iov[cur++] = (struct iovec){ (void *)padding, padlen - 1 };
 
        kr_assert(cur == iovcnt);
        int ret = session2_wrap_after(http->session, PROTOLAYER_HTTP,
-                       protolayer_iovec(iov, cur),
+                       protolayer_iovec(dest_iov, cur),
                        NULL, callback_finished_free_baton, sdctx);
 
        if (ret < 0)
                return ret;
        return 0;
-#else
-       struct protolayer_iter_ctx *ctx = source->ptr;
-       if (kr_fails_assert(ctx)) {
-               return NGHTTP2_ERR_WOULDBLOCK;
-       }
-
-       size_t total_len = HTTP_FRAME_HDLEN + length + frame->data.padlen;
-       struct http_send_ctx *send_ctx = malloc(sizeof(*send_ctx) + total_len);
-       kr_require(send_ctx);
-
-       send_ctx->sess_data = http;
-       uint8_t *cur = send_ctx->data;
-
-       /* TODO - remove these unnecessary copies */
-
-       /* Frame header */
-       memcpy(cur, framehd, HTTP_FRAME_HDLEN);
-       cur += HTTP_FRAME_HDLEN;
-
-       /* Length of frame padding */
-       if (frame->data.padlen) {
-               *cur = frame->data.padlen - 1;
-               cur++;
-       }
-
-       /* Data */
-       size_t copied = protolayer_payload_copy(cur, &ctx->payload, length);
-       cur += copied;
-
-       /* Padding */
-       if (frame->data.padlen > 1)
-               bzero(cur, frame->data.padlen - 1);
-
-       kr_log_debug(DOH, "[%p] send_data_callback: %p\n", (void *)h2, (void *)send_ctx->data);
-       session2_wrap_after(http->session, PROTOLAYER_HTTP,
-                       protolayer_buffer(send_ctx->data, total_len), NULL,
-                       callback_finished_free_baton, send_ctx);
-
-       return 0;
-#endif
 }
 
 /*
@@ -603,6 +613,8 @@ static int header_callback(nghttp2_session *h2, const nghttp2_frame *frame,
        }
 
        if (!strcasecmp("content-type", (const char *)name)) {
+               /* TODO: add a per-group option for content-type if we need to
+                * support protocols other than DNS here */
                if (strcasecmp("application/dns-message", (const char *)value)) {
                        set_status(ctx, HTTP_STATUS_UNSUPPORTED_MEDIA_TYPE);
                        return 0;
@@ -734,6 +746,9 @@ static int on_frame_recv_callback(nghttp2_session *h2, const nghttp2_frame *fram
                        }
                }
 
+               if (!http_status_has_category(ctx->status, 2))
+                       return 0;
+
                if (submit_to_wirebuffer(ctx) < 0)
                        return NGHTTP2_ERR_CALLBACK_FAILURE;
        }
index 52042f49b6e5f71ab927f92cfae65853b642c719..d0cb937b499721f53f1f74705b18407070bffdc9 100644 (file)
@@ -73,7 +73,7 @@ const char *protolayer_protocol_names[PROTOLAYER_PROTOCOL_COUNT] = {
  * one defined as *Variable name* (2nd parameter) in the `PROTOLAYER_GRP_MAP`
  * macro. */
 static enum protolayer_protocol *protolayer_grps[PROTOLAYER_GRP_COUNT] = {
-#define XX(cid, vid, name, alpn) [PROTOLAYER_GRP_##cid] = protolayer_grp_##vid,
+#define XX(cid, vid, name) [PROTOLAYER_GRP_##cid] = protolayer_grp_##vid,
        PROTOLAYER_GRP_MAP(XX)
 #undef XX
 };
@@ -81,7 +81,7 @@ static enum protolayer_protocol *protolayer_grps[PROTOLAYER_GRP_COUNT] = {
 /** Human-readable names for protocol layer groups. */
 const char *protolayer_grp_names[PROTOLAYER_GRP_COUNT] = {
        [PROTOLAYER_GRP_NULL] = "(null)",
-#define XX(cid, vid, name, alpn) [PROTOLAYER_GRP_##cid] = name,
+#define XX(cid, vid, name) [PROTOLAYER_GRP_##cid] = name,
        PROTOLAYER_GRP_MAP(XX)
 #undef XX
 };
index 880b4206a16b5c86982004cfecf330cb835e17bb..afc6f5625f80892c835d04c422c9daa40adc793c 100644 (file)
@@ -122,18 +122,17 @@ extern const char *protolayer_protocol_names[];
  * Parameters are:
  *   1. Constant name (for e.g. PROTOLAYER_GRP_* constants)
  *   2. Variable name (for e.g. protolayer_grp_* arrays)
- *   3. Human-readable name for logging
- *   4. ALPN protocol identifier (for TLS) */
+ *   3. Human-readable name for logging */
 #define PROTOLAYER_GRP_MAP(XX) \
-       XX(DOUDP, doudp, "DNS UDP", "") \
-       XX(DOTCP, dotcp, "DNS TCP", "") \
-       XX(DOTLS, dot, "DNS-over-TLS", "dot") \
-       XX(DOHTTPS, doh, "DNS-over-HTTPS", "h2")
+       XX(DOUDP, doudp, "DNS UDP") \
+       XX(DOTCP, dotcp, "DNS TCP") \
+       XX(DOTLS, dot, "DNS-over-TLS") \
+       XX(DOHTTPS, doh, "DNS-over-HTTPS")
 
 /** The identifiers of pre-defined protocol layer sequences. */
 enum protolayer_grp {
        PROTOLAYER_GRP_NULL = 0,
-#define XX(cid, vid, name, alpn) PROTOLAYER_GRP_##cid,
+#define XX(cid, vid, name) PROTOLAYER_GRP_##cid,
        PROTOLAYER_GRP_MAP(XX)
 #undef XX
        PROTOLAYER_GRP_COUNT
index 2f1773353eb0bbe5df791b16236f5619d5f5a36b..ad9c21f3437b750f9c4aa62b55e3e8adbf5208d5 100644 (file)
 #endif
 
 static const gnutls_datum_t tls_grp_alpn[PROTOLAYER_GRP_COUNT] = {
-#define XX(cid, vid, name, alpn) [PROTOLAYER_GRP_##cid] = \
-       { .data = (unsigned char *)alpn, .size = sizeof(alpn) - 1 },
-       PROTOLAYER_GRP_MAP(XX)
-#undef XX
+       [PROTOLAYER_GRP_DOTLS] = { (uint8_t *)"dot", 3 },
+       [PROTOLAYER_GRP_DOHTTPS] = { (uint8_t *)"h2", 2 },
 };
 
 typedef enum tls_client_hs_state {