From: Charles-Henri Bruyand Date: Mon, 4 Dec 2023 14:58:03 +0000 (+0100) Subject: dnsdist: doh3 add support for post queries X-Git-Tag: dnsdist-1.9.0-alpha4~15^2~5 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=d0439b4299236debc52ea322aa0ce01b47a3a2b9;p=thirdparty%2Fpdns.git dnsdist: doh3 add support for post queries --- diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 8d3334d35d..bc597a6611 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -215,8 +215,8 @@ dnsdist_SOURCES = \ doh.hh \ doh3.hh \ dolog.cc dolog.hh \ - doq.hh \ doq-common.hh \ + doq.hh \ ednscookies.cc ednscookies.hh \ ednsextendederror.cc ednsextendederror.hh \ ednsoptions.cc ednsoptions.hh \ diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index 02a2ea928c..71f3adbd7a 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -42,6 +42,13 @@ #include "doq-common.hh" +#if 0 +#define DEBUGLOG_ENABLED +#define DEBUGLOG(x) std::cerr << x << std::endl; +#else +#define DEBUGLOG(x) +#endif + using namespace dnsdist::doq; class H3Connection @@ -626,7 +633,6 @@ void doh3Thread(ClientState* clientState) Socket sock(clientState->udpFD); - PacketBuffer buffer(std::numeric_limits::max()); auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent()); auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor(); @@ -726,6 +732,7 @@ void doh3Thread(ClientState* clientState) DEBUGLOG("Successfully created HTTP/3 connection"); } + std::map headers; while (true) { quiche_h3_event* ev; // Processes HTTP/3 data received from the peer @@ -739,20 +746,17 @@ void doh3Thread(ClientState* clientState) switch (quiche_h3_event_type(ev)) { case QUICHE_H3_EVENT_HEADERS: { - std::string path; int rc = quiche_h3_event_for_each_header( ev, [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int { std::string_view key(reinterpret_cast(name), name_len); std::string_view content(reinterpret_cast(value), value_len); - if (key == ":path") { - auto pathptr = reinterpret_cast(argp); - *pathptr = content; - } + auto headersptr = reinterpret_cast*>(argp); + headersptr->emplace(key, content); return 0; }, - &path); - if (rc != 0) { + &headers); + if (rc != 0 || !headers.count(":method")) { DEBUGLOG("Failed to process headers"); ++dnsdist::metrics::g_stats.nonCompliantQueries; ++clientState->nonCompliantQueries; @@ -760,63 +764,20 @@ void doh3Thread(ClientState* clientState) h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to process query headers"); break; } - if (path.empty()) { - DEBUGLOG("Path not found"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found"); - break; - } - { - auto pos = path.find("?dns="); - if (pos == string::npos) { - pos = path.find("&dns="); + if (headers.at(":method") == "GET") { + if (!headers.count(":path") || headers.at(":path").empty()) { + DEBUGLOG("Path not found"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found"); + break; } - if (pos != string::npos) { - // need to base64url decode this - string sdns(path.substr(pos + 5)); - boost::replace_all(sdns, "-", "+"); - boost::replace_all(sdns, "_", "/"); - // re-add padding that may have been missing - switch (sdns.size() % 4) { - case 2: - sdns.append(2, '='); - break; - case 3: - sdns.append(1, '='); - break; - } - - PacketBuffer decoded; - - /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */ - const size_t maxAdditionalSizeForEDNS = 35U; - /* rough estimate so we hopefully don't need a new allocation later */ - /* We reserve at few additional bytes to be able to add EDNS later */ - const size_t estimate = ((sdns.size() * 3) / 4); - decoded.reserve(estimate + maxAdditionalSizeForEDNS); - if (B64Decode(sdns, decoded) < 0) { - DEBUGLOG("Unable to base64 decode()"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL"); - break; - } - - if (decoded.size() < sizeof(dnsheader)) { - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query"); - break; - } - DEBUGLOG("Dispatching query"); - doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID); - conn->get().d_streamBuffers.erase(streamID); + auto pos = headers.at(":path").find("?dns="); + if (pos == string::npos) { + pos = headers.at(":path").find("&dns="); } - else { + if (pos == string::npos) { DEBUGLOG("User error, unable to find the DNS parameter"); ++dnsdist::metrics::g_stats.nonCompliantQueries; ++clientState->nonCompliantQueries; @@ -824,11 +785,99 @@ void doh3Thread(ClientState* clientState) h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to find the DNS parameter"); break; } + // need to base64url decode this + string sdns(headers.at(":path").substr(pos + 5)); + boost::replace_all(sdns, "-", "+"); + boost::replace_all(sdns, "_", "/"); + // re-add padding that may have been missing + switch (sdns.size() % 4) { + case 2: + sdns.append(2, '='); + break; + case 3: + sdns.append(1, '='); + break; + } + + PacketBuffer decoded; + /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */ + const size_t maxAdditionalSizeForEDNS = 35U; + /* rough estimate so we hopefully don't need a new allocation later */ + /* We reserve at few additional bytes to be able to add EDNS later */ + const size_t estimate = ((sdns.size() * 3) / 4); + decoded.reserve(estimate + maxAdditionalSizeForEDNS); + if (B64Decode(sdns, decoded) < 0) { + DEBUGLOG("Unable to base64 decode()"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL"); + break; + } + if (decoded.size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query"); + break; + } + DEBUGLOG("Dispatching GET query"); + doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID); + conn->get().d_streamBuffers.erase(streamID); + } + else if (headers.at(":method") == "POST") { + if (!quiche_h3_event_headers_has_body(ev)) { + DEBUGLOG("Empty POST query"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Empty POST query"); + break; + } + } + else { + DEBUGLOG("Unsupported HTTP method"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported HTTP method"); + break; } break; } + case QUICHE_H3_EVENT_DATA: { + if (!headers.count("content-type") || headers.at("content-type") != "application/dns-message") { + DEBUGLOG("Unsupported content-type"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported content-type"); + break; + } + PacketBuffer buffer(std::numeric_limits::max()); + PacketBuffer decoded; + + while (true) { + ssize_t len = quiche_h3_recv_body(conn->get().d_http3.get(), + conn->get().d_conn.get(), streamID, + buffer.data(), buffer.capacity()); - case QUICHE_H3_EVENT_DATA: + if (len <= 0) { + break; + } + decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len); + } + if (decoded.size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query"); + break; + } + DEBUGLOG("Dispatching POST query"); + doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID); + conn->get().d_streamBuffers.erase(streamID); + } case QUICHE_H3_EVENT_FINISHED: case QUICHE_H3_EVENT_RESET: case QUICHE_H3_EVENT_PRIORITY_UPDATE: diff --git a/pdns/dnsdistdist/doq-common.cc b/pdns/dnsdistdist/doq-common.cc index 211928aee5..4b0b2868f9 100644 --- a/pdns/dnsdistdist/doq-common.cc +++ b/pdns/dnsdistdist/doq-common.cc @@ -26,6 +26,13 @@ #ifdef HAVE_DNS_OVER_QUIC +#if 0 +#define DEBUGLOG_ENABLED +#define DEBUGLOG(x) std::cerr << x << std::endl; +#else +#define DEBUGLOG(x) +#endif + namespace dnsdist::doq { diff --git a/pdns/dnsdistdist/doq-common.hh b/pdns/dnsdistdist/doq-common.hh index c8064a8606..7af19ecb79 100644 --- a/pdns/dnsdistdist/doq-common.hh +++ b/pdns/dnsdistdist/doq-common.hh @@ -36,13 +36,6 @@ #include "sstuff.hh" #include "libssl.hh" -#if 0 -#define DEBUGLOG_ENABLED -#define DEBUGLOG(x) std::cerr << x << std::endl; -#else -#define DEBUGLOG(x) -#endif - namespace dnsdist::doq { diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 7037349652..66a71b8504 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -43,6 +43,13 @@ using namespace dnsdist::doq; +#if 0 +#define DEBUGLOG_ENABLED +#define DEBUGLOG(x) std::cerr << x << std::endl; +#else +#define DEBUGLOG(x) +#endif + class Connection { public: @@ -576,7 +583,6 @@ void doqThread(ClientState* clientState) Socket sock(clientState->udpFD); - PacketBuffer buffer(std::numeric_limits::max()); auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent()); auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor(); diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 7b1cca3530..6dff1318eb 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -1129,7 +1129,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): return (receivedQuery, message) @classmethod - def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None): + def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False): if response: if toQueue: @@ -1137,7 +1137,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): else: cls._toResponderQueue.put(response, True, timeout) - message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName) + message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post) receivedQuery = None diff --git a/regression-tests.dnsdist/doh3client.py b/regression-tests.dnsdist/doh3client.py index c04b06b152..85d66a399a 100644 --- a/regression-tests.dnsdist/doh3client.py +++ b/regression-tests.dnsdist/doh3client.py @@ -24,53 +24,10 @@ from aioquic.h3.events import ( ) from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset -#from aioquic.quic.logger import QuicFileLogger from aioquic.tls import CipherSuite, SessionTicket from doqclient import StreamResetError -# -#class DnsClientProtocol(QuicConnectionProtocol): -# def __init__(self, *args, **kwargs): -# super().__init__(*args, **kwargs) -# self._ack_waiter: Any = None -# -# def pack(self, data): -# # serialize query -# data = bytes(data) -# data = struct.pack("!H", len(data)) + data -# return data -# -# async def query(self, query: dns.message) -> None: -# data = self.pack(query.to_wire()) -# # send query and wait for answer -# stream_id = self._quic.get_next_available_stream_id() -# self._quic.send_stream_data(stream_id, data, end_stream=True) -# waiter = self._loop.create_future() -# self._ack_waiter = waiter -# self.transmit() -# -# return await asyncio.shield(waiter) -# -# def quic_event_received(self, event: QuicEvent) -> None: -# if self._ack_waiter is not None: -# if isinstance(event, StreamDataReceived): -# length = struct.unpack("!H", bytes(event.data[:2]))[0] -# answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True) -# -# waiter = self._ack_waiter -# self._ack_waiter = None -# waiter.set_result(answer) -# if isinstance(event, StreamReset): -# waiter = self._ack_waiter -# self._ack_waiter = None -# waiter.set_result(event) -# -#class BogusDnsClientProtocol(DnsClientProtocol): -# def pack(self, data): -# # serialize query -# data = bytes(data) -# data = struct.pack("!H", len(data) * 2) + data -# return data + HttpConnection = Union[H0Connection, H3Connection] class URL: @@ -195,20 +152,19 @@ class HttpClient(QuicConnectionProtocol): async def perform_http_request( client: HttpClient, url: str, - data: Optional[str], + data: Optional[bytes], include: bool, output_dir: Optional[str], ) -> None: # perform request start = time.time() if data is not None: - data_bytes = data.encode() http_events = await client.post( url, - data=data_bytes, + data=data, headers={ - "content-length": str(len(data_bytes)), - "content-type": "application/x-www-form-urlencoded", + "content-length": str(len(data)), + "content-type": "application/dns-message", }, ) method = "POST" @@ -232,10 +188,13 @@ async def async_h3_query( port: int, query: dns.message, timeout: float, - create_protocol=HttpClient + post: bool, + create_protocol=HttpClient, ) -> None: - url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('=')) + url = baseurl + if not post: + url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('=')) async with connect( "127.0.0.1", port, @@ -250,7 +209,7 @@ async def async_h3_query( answer = await perform_http_request( client=client, url=url, - data=None, + data=query.to_wire() if post else None, include=False, output_dir=None, ) @@ -259,10 +218,12 @@ async def async_h3_query( except asyncio.TimeoutError as e: return e -def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None): + +def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False): configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True) if verify: configuration.load_verify_locations(verify) + result = asyncio.run( async_h3_query( configuration=configuration, @@ -270,7 +231,8 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname port=port, query=query, timeout=timeout, - create_protocol=HttpClient + create_protocol=HttpClient, + post=post ) ) diff --git a/regression-tests.dnsdist/test_DOH3.py b/regression-tests.dnsdist/test_DOH3.py index dcff35e096..4374bc2a9f 100644 --- a/regression-tests.dnsdist/test_DOH3.py +++ b/regression-tests.dnsdist/test_DOH3.py @@ -32,3 +32,41 @@ class TestDOH3(QUICTests, DNSDistTest): def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection) + +class TestDOH3Specifics(DNSDistTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _doqServerPort = pickAvailablePort() + _dohBaseURL = ("https://%s:%d/" % (_serverName, _doqServerPort)) + _config_template = """ + newServer{address="127.0.0.1:%d"} + + addDOH3Local("127.0.0.1:%d", "%s", "%s", {keyLogFile='/tmp/keys'}) + """ + _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] + _verboseMode = True + + def testDOH3Post(self): + """ + QUIC: Simple POST query + """ + name = 'simple.post.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, serverName=self._serverName, post=True) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.assertEqual(receivedResponse, response)