]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon/http: return 400 on failed packet_parse + improved stream handling
authorOto Šťáva <oto.stava@nic.cz>
Fri, 1 Apr 2022 08:42:36 +0000 (10:42 +0200)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Fri, 20 May 2022 07:45:34 +0000 (09:45 +0200)
daemon/http.c
daemon/io.c
daemon/worker.c
tests/config/doh2.test.lua

index d3f8d2cd6a0b5cf466a3c27e3bc57ed09e66e5e6..47dc75389922f0e7dd1a28b21730bfb4f4f8fe87 100644 (file)
@@ -57,6 +57,7 @@ enum http_status {
        HTTP_STATUS_BAD_REQUEST                     = 400,
        HTTP_STATUS_NOT_FOUND                       = 404,
        HTTP_STATUS_PAYLOAD_TOO_LARGE               = 413,
+       HTTP_STATUS_UNSUPPORTED_MEDIA_TYPE          = 415,
        HTTP_STATUS_REQUEST_HEADER_FIELDS_TOO_LARGE = 431,
        HTTP_STATUS_NOT_IMPLEMENTED                 = 501,
 };
@@ -74,6 +75,8 @@ typedef array_t(nghttp2_nv) nghttp2_array_t;
 
 static int http_send_response(struct http_ctx *ctx, int32_t stream_id,
                              nghttp2_data_provider *prov, enum http_status status);
+static int http_send_response_rst_stream(struct http_ctx *ctx, int32_t stream_id,
+                             nghttp2_data_provider *prov, enum http_status status);
 
 /*
  * Write HTTP/2 protocol data to underlying transport layer.
@@ -219,7 +222,7 @@ static int check_uri(const char* uri_path)
                        }
                        end_prev = beg + strlen(beg);
                        beg = strtok(NULL, delim);
-                       if (beg-1 != end_prev) { /* detect && */
+                       if (!beg || beg-1 != end_prev) { /* detect && */
                                return -1;
                        }
                }
@@ -368,7 +371,7 @@ static int header_callback(nghttp2_session *h2, const nghttp2_frame *frame,
                                kr_log_debug(DOH,
                                        "[%p] stream %d: header too large (%zu B), refused\n",
                                        (void *)h2, stream_id, valuelen);
-                               return http_send_response(ctx, stream_id, NULL,
+                               return http_send_response_rst_stream(ctx, stream_id, NULL,
                                                HTTP_STATUS_REQUEST_HEADER_FIELDS_TOO_LARGE);
                        }
 
@@ -389,9 +392,11 @@ static int header_callback(nghttp2_session *h2, const nghttp2_frame *frame,
        if (!strcasecmp(":path", (const char *)name)) {
                int uri_result = check_uri((const char *)value);
                if (uri_result == kr_error(ENOENT)) {
-                       return http_send_response(ctx, stream_id, NULL, HTTP_STATUS_NOT_FOUND);
+                       return http_send_response_rst_stream(ctx, stream_id, NULL,
+                                       HTTP_STATUS_NOT_FOUND);
                } else if (uri_result < 0) {
-                       return http_send_response(ctx, stream_id, NULL, HTTP_STATUS_BAD_REQUEST);
+                       return http_send_response_rst_stream(ctx, stream_id, NULL,
+                                       HTTP_STATUS_BAD_REQUEST);
                }
 
                kr_assert(ctx->uri_path == NULL);
@@ -411,7 +416,15 @@ static int header_callback(nghttp2_session *h2, const nghttp2_frame *frame,
                        ctx->current_method = HTTP_METHOD_HEAD;
                } else {
                        ctx->current_method = HTTP_METHOD_NONE;
-                       return http_send_response(ctx, stream_id, NULL, HTTP_STATUS_NOT_IMPLEMENTED);
+                       return http_send_response_rst_stream(ctx, stream_id, NULL,
+                                       HTTP_STATUS_NOT_IMPLEMENTED);
+               }
+       }
+
+       if (!strcasecmp("content-type", (const char *)name)) {
+               if (strcasecmp("application/dns-message", (const char *)value)) {
+                       return http_send_response_rst_stream(ctx, stream_id, NULL,
+                                       HTTP_STATUS_UNSUPPORTED_MEDIA_TYPE);
                }
        }
 
@@ -496,7 +509,7 @@ static int submit_to_wirebuffer(struct http_ctx *ctx)
        len = ctx->buf_pos - sizeof(uint16_t);
        if (len <= 0 || len > KNOT_WIRE_MAX_PKTSIZE) {
                kr_log_debug(DOH, "[%p] invalid dnsmsg size: %zd B\n", (void *)ctx->h2, len);
-               http_send_response(ctx, stream_id, NULL, (len <= 0)
+               http_send_response_rst_stream(ctx, stream_id, NULL, (len <= 0)
                                ? HTTP_STATUS_BAD_REQUEST
                                : HTTP_STATUS_PAYLOAD_TOO_LARGE);
                ret = -1;
@@ -536,7 +549,8 @@ static int on_frame_recv_callback(nghttp2_session *h2, const nghttp2_frame *fram
                if (ctx->current_method == HTTP_METHOD_GET || ctx->current_method == HTTP_METHOD_HEAD) {
                        if (process_uri_path(ctx, ctx->uri_path, stream_id) < 0) {
                                /* End processing - don't submit to wirebuffer. */
-                               return http_send_response(ctx, stream_id, NULL, HTTP_STATUS_BAD_REQUEST);
+                               return http_send_response_rst_stream(ctx, stream_id, NULL,
+                                               HTTP_STATUS_BAD_REQUEST);
                        }
                }
 
@@ -623,7 +637,9 @@ struct http_ctx* http_new(struct session *session, http_send_callback send_cb)
        queue_init(ctx->streams);
        ctx->stream_write_data = trie_create(NULL);
        ctx->incomplete_stream = -1;
+       ctx->submitted_stream = -1;
        ctx->submitted = 0;
+       ctx->streaming = true;
        ctx->current_method = HTTP_METHOD_NONE;
        ctx->uri_path = NULL;
 
@@ -691,7 +707,8 @@ int http_send_bad_request(struct session *session)
 {
        struct http_ctx *ctx = session_http_get_server_ctx(session);
        if (ctx->submitted_stream >= 0)
-               return http_send_response(ctx, ctx->submitted_stream, NULL, HTTP_STATUS_BAD_REQUEST);
+               return http_send_response_rst_stream(ctx, ctx->submitted_stream, NULL,
+                               HTTP_STATUS_BAD_REQUEST);
 
        return 0;
 }
@@ -814,13 +831,27 @@ static int http_send_response(struct http_ctx *ctx, int32_t stream_id,
                return 0;
        }
 
-       if (status != HTTP_STATUS_OK) {
-               nghttp2_submit_rst_stream(h2, NGHTTP2_FLAG_NONE, stream_id, NGHTTP2_NO_ERROR);
-       }
-
        return 0;
 }
 
+/*
+ * Same as `http_send_response`, but resets the HTTP stream afterwards. Used
+ * for sending negative status messages.
+ */
+static int http_send_response_rst_stream(struct http_ctx *ctx, int32_t stream_id,
+                             nghttp2_data_provider *prov, enum http_status status)
+{
+       int ret = http_send_response(ctx, stream_id, prov, status);
+       if (ret)
+               return ret;
+
+       ctx->submitted_stream = -1;
+       nghttp2_submit_rst_stream(ctx->h2, NGHTTP2_FLAG_NONE, stream_id, NGHTTP2_NO_ERROR);
+       ret = nghttp2_session_send(ctx->h2);
+       return ret;
+}
+
+
 /*
  * Send HTTP/2 stream data created from packet's wire buffer.
  *
index ff2e25a50c5963ec9420fad8304a4d0d4e17d90f..30b3f453ef2d22fc545e7d0c3155f2439e7191da 100644 (file)
@@ -478,7 +478,7 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
 #if ENABLE_DOH2
        if (session_flags(s)->has_http && streaming == 0 && ret == 0) {
                ret = http_send_bad_request(s);
-               if (ret < 0) {
+               if (ret) {
                        /* An error has occurred, close the session. */
                        worker_end_tcp(s);
                }
index b2e421abbb0508204ca8e7918184f845fad332f0..156335d314047fe5fb33ebb6292fdcc11cf50620 100644 (file)
@@ -1778,6 +1778,10 @@ int worker_submit(struct session *session, struct io_comm_data *comm,
        struct http_ctx *http_ctx = NULL;
 #if ENABLE_DOH2
        http_ctx = session_http_get_server_ctx(session);
+       if (http_ctx && !is_outgoing && ret) {
+               http_send_bad_request(session);
+               return kr_error(EMSGSIZE);
+       }
 #endif
 
        if (!is_outgoing && http_ctx && queue_len(http_ctx->streams) <= 0)
index 429963c9a7ee6373c174388a92d5c53c4be669fd..352ea42834285d6c90410932aa6a005f1f6d782d 100644 (file)
@@ -63,15 +63,15 @@ local function check_ok(req, desc)
        return headers, pkt
 end
 
---local function check_err(req, exp_status, desc)
---     local headers, errmsg, errno = req:go(16)
---     if errno then
---             nok(errmsg, desc .. ': ' .. errmsg)
---             return
---     end
---     local got_status = headers:get(':status')
---     same(got_status, exp_status, desc)
---end
+local function check_err(req, exp_status, desc)
+       local headers, errmsg, errno = req:go(16)
+       if errno then
+               nok(errmsg, desc .. ': ' .. errmsg)
+               return
+       end
+       local got_status = headers:get(':status')
+       same(got_status, exp_status, desc)
+end
 
 -- check prerequisites
 local bound, port
@@ -169,35 +169,39 @@ else
        end
 
        -- test an invalid DNS query using POST
---     local function test_post_short_input()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'POST')
---             req:set_body(string.rep('0', 11))  -- 11 bytes < DNS msg header
---             check_err(req, '400', 'too short POST finishes with 400')
---     end
---
+       local function test_post_short_input()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'POST')
+               req:set_body(string.rep('0', 11))  -- 11 bytes < DNS msg header
+               check_err(req, '400', 'too short POST finishes with 400')
+       end
+
 --     local function test_post_long_input()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'POST')
---             req:set_body(string.rep('s', 1025))  -- > DNS msg over UDP
---             check_err(req, '413', 'too long POST finishes with 413')
---     end
---
---     local function test_post_unparseable_input()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'POST')
---             req:set_body(string.rep('\0', 1024))  -- garbage
---             check_err(req, '400', 'unparseable DNS message finishes with 400')
---     end
+--             -- FIXME: This test is broken in Lua. The connection times out
+--             -- for some reason, but sending a request like this with `curl`
+--             -- or PowerShell's `Invoke-RestMethod` provides correct results.
 --
---     local function test_post_unsupp_type()
 --             local req = assert(req_templ:clone())
 --             req.headers:upsert(':method', 'POST')
---             req.headers:upsert('content-type', 'application/dns+json')
---             req:set_body(string.rep('\0', 12))  -- valid message
---             check_err(req, '415', 'unsupported request content type finishes with 415')
+--             req:set_body(string.rep('s', 1025))  -- > DNS msg over UDP
+--             check_err(req, '400', 'too long POST finishes with 400')
 --     end
 
+       local function test_post_unparseable_input()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'POST')
+               req:set_body(string.rep('\0', 1024))  -- garbage
+               check_err(req, '400', 'unparseable DNS message finishes with 400')
+       end
+
+       local function test_post_unsupp_type()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'POST')
+               req.headers:upsert('content-type', 'application/dns+json')
+               req:set_body(string.rep('\0', 12))  -- valid message
+               check_err(req, '415', 'unsupported request content type finishes with 415')
+       end
+
        -- test a valid DNS query using GET
        local function test_get_servfail()
                local desc = 'valid GET query which ends with SERVFAIL'
@@ -275,47 +279,47 @@ else
                check_ok(req, desc)
        end
 
---     -- test an invalid DNS query using GET
---             local function test_get_long_input()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'GET')
---             req.headers:upsert(':path', '/doh?dns=' .. basexx.to_url64(string.rep('\0', 1030)))
---             check_err(req, '414', 'too long GET finishes with 414')
---     end
---
---     local function test_get_no_dns_param()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'GET')
---             req.headers:upsert(':path', '/doh?notdns=' .. basexx.to_url64(string.rep('\0', 1024)))
---             check_err(req, '400', 'GET without dns parameter finishes with 400')
---     end
---
---     local function test_get_unparseable()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'GET')
---             req.headers:upsert(':path', '/doh??dns=' .. basexx.to_url64(string.rep('\0', 1024)))
---             check_err(req, '400', 'unparseable GET finishes with 400')
---     end
---
---     local function test_get_invalid_b64()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'GET')
---             req.headers:upsert(':path', '/doh?dns=thisisnotb64')
---             check_err(req, '400', 'GET with invalid base64 finishes with 400')
---     end
---
---     local function test_get_invalid_chars()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'GET')
---             req.headers:upsert(':path', '/doh?dns=' .. basexx.to_url64(string.rep('\0', 200)) .. '@#$%?!')
---             check_err(req, '400', 'GET with invalid characters in b64 finishes with 400')
---     end
---
---     local function test_unsupp_method()
---             local req = assert(req_templ:clone())
---             req.headers:upsert(':method', 'PUT')
---             check_err(req, '405', 'unsupported method finishes with 405')
---     end
+       -- test an invalid DNS query using GET
+               local function test_get_long_input()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'GET')
+               req.headers:upsert(':path', '/doh?dns=' .. basexx.to_url64(string.rep('\0', 1030)))
+               check_err(req, '400', 'too long GET finishes with 400')
+       end
+
+       local function test_get_no_dns_param()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'GET')
+               req.headers:upsert(':path', '/doh?notdns=' .. basexx.to_url64(string.rep('\0', 1024)))
+               check_err(req, '400', 'GET without dns parameter finishes with 400')
+       end
+
+       local function test_get_unparseable()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'GET')
+               req.headers:upsert(':path', '/doh??dns=' .. basexx.to_url64(string.rep('\0', 1024)))
+               check_err(req, '400', 'unparseable GET finishes with 400')
+       end
+
+       local function test_get_invalid_b64()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'GET')
+               req.headers:upsert(':path', '/doh?dns=thisisnotb64')
+               check_err(req, '400', 'GET with invalid base64 finishes with 400')
+       end
+
+       local function test_get_invalid_chars()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'GET')
+               req.headers:upsert(':path', '/doh?dns=' .. basexx.to_url64(string.rep('\0', 200)) .. '@#$%?!')
+               check_err(req, '400', 'GET with invalid characters in b64 finishes with 400')
+       end
+
+       local function test_unsupp_method()
+               local req = assert(req_templ:clone())
+               req.headers:upsert(':method', 'PUT')
+               check_err(req, '501', 'unsupported method finishes with 501')
+       end
 
        local function test_dstaddr()
                local triggered = false
@@ -438,29 +442,28 @@ else
 --     end
 
        -- plan tests
-       -- TODO: implement (some) of the error status codes
        local tests = {
                start_server,
                test_post_servfail,
                test_post_noerror,
                test_post_nxdomain,
                test_huge_answer,
-               --test_post_short_input,
-               --test_post_long_input,
-               --test_post_unparseable_input,
-               --test_post_unsupp_type,
+               test_post_short_input,
+--             test_post_long_input, -- FIXME see the test function
+               test_post_unparseable_input,
+               test_post_unsupp_type,
                test_get_servfail,
                test_get_noerror,
                test_get_nxdomain,
                test_get_other_params_before_dns,
                test_get_other_params_after_dns,
                test_get_other_params,
-               --test_get_long_input,
-               --test_get_no_dns_param,
-               --test_get_unparseable,
-               --test_get_invalid_b64,
-               --test_get_invalid_chars,
-               --test_unsupp_method,
+               test_get_long_input,
+               test_get_no_dns_param,
+               test_get_unparseable,
+               test_get_invalid_b64,
+               test_get_invalid_chars,
+               test_unsupp_method,
                test_dstaddr,
                test_srcaddr,
                test_headers