From: Remi Gacogne Date: Mon, 30 Dec 2024 14:51:01 +0000 (+0100) Subject: dnsdist: Add DoH3 headers, query string, path and scheme bindings X-Git-Tag: dnsdist-2.0.0-alpha1~187^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6c419450e7dcbf20f427349ec3dc174caa7f2a43;p=thirdparty%2Fpdns.git dnsdist: Add DoH3 headers, query string, path and scheme bindings The DoH ones have been there for a long time, but the DoH3 ones were missing. Note that we still don't have the ability to set a HTTP response for DoH3 queries (including response maps) and SNI is still missing (Quiche does not make that last one easy). --- diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 4c9f650ce3..2e445dff40 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -480,12 +480,15 @@ endif if HAVE_DNS_OVER_HTTP3 dnsdist_SOURCES += doh3.cc +testrunner_SOURCES += doh3.cc endif if HAVE_QUICHE AM_CPPFLAGS += $(QUICHE_CFLAGS) dnsdist_LDADD += $(QUICHE_LDFLAGS) $(QUICHE_LIBS) dnsdist_SOURCES += doq-common.cc +testrunner_SOURCES += doq-common.cc +testrunner_LDADD += $(QUICHE_LDFLAGS) $(QUICHE_LIBS) endif if !HAVE_LUA_HPP diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index 4ff9ff2900..d6f48b9390 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -32,15 +32,23 @@ HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& reg bool HTTPHeaderRule::matches(const DNSQuestion* dq) const { - if (!dq->ids.du) { + if (dq->ids.du) { + const auto& headers = dq->ids.du->getHTTPHeaders(); + for (const auto& header : headers) { + if (header.first == d_header) { + return d_regex.match(header.second); + } + } return false; } - - const auto& headers = dq->ids.du->getHTTPHeaders(); - for (const auto& header : headers) { - if (header.first == d_header) { - return d_regex.match(header.second); + if (dq->ids.doh3u) { + const auto& headers = dq->ids.doh3u->getHTTPHeaders(); + for (const auto& header : headers) { + if (header.first == d_header) { + return d_regex.match(header.second); + } } + return false; } return false; } @@ -57,12 +65,14 @@ HTTPPathRule::HTTPPathRule(std::string path) : bool HTTPPathRule::matches(const DNSQuestion* dq) const { - if (!dq->ids.du) { - return false; + if (dq->ids.du) { + const auto path = dq->ids.du->getHTTPPath(); + return d_path == path; } - - const auto path = dq->ids.du->getHTTPPath(); - return d_path == path; + else if (dq->ids.doh3u) { + return dq->ids.doh3u->getHTTPPath() == d_path; + } + return false; } string HTTPPathRule::toString() const @@ -77,11 +87,14 @@ HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex) : bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const { - if (!dq->ids.du) { - return false; + if (dq->ids.du) { + const auto path = dq->ids.du->getHTTPPath(); + return d_regex.match(path); } - - return d_regex.match(dq->ids.du->getHTTPPath()); + else if (dq->ids.doh3u) { + return d_regex.match(dq->ids.doh3u->getHTTPPath()); + } + return false; } string HTTPPathRegexRule::toString() const diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc index 3527ce7ba2..58546e1029 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc @@ -513,38 +513,53 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) #ifdef HAVE_DNS_OVER_HTTPS luaCtx.registerFunction("getHTTPPath", [](const DNSQuestion& dnsQuestion) { - if (dnsQuestion.ids.du == nullptr) { - return std::string(); + if (dnsQuestion.ids.du) { + return dnsQuestion.ids.du->getHTTPPath(); + } + if (dnsQuestion.ids.doh3u) { + return dnsQuestion.ids.doh3u->getHTTPPath(); } - return dnsQuestion.ids.du->getHTTPPath(); + return std::string(); }); luaCtx.registerFunction("getHTTPQueryString", [](const DNSQuestion& dnsQuestion) { - if (dnsQuestion.ids.du == nullptr) { - return std::string(); + if (dnsQuestion.ids.du) { + return dnsQuestion.ids.du->getHTTPQueryString(); } - return dnsQuestion.ids.du->getHTTPQueryString(); + if (dnsQuestion.ids.doh3u) { + return dnsQuestion.ids.doh3u->getHTTPQueryString(); + } + return std::string(); }); luaCtx.registerFunction("getHTTPHost", [](const DNSQuestion& dnsQuestion) { - if (dnsQuestion.ids.du == nullptr) { - return std::string(); + if (dnsQuestion.ids.du) { + return dnsQuestion.ids.du->getHTTPHost(); } - return dnsQuestion.ids.du->getHTTPHost(); + if (dnsQuestion.ids.doh3u) { + return dnsQuestion.ids.doh3u->getHTTPHost(); + } + return std::string(); }); luaCtx.registerFunction("getHTTPScheme", [](const DNSQuestion& dnsQuestion) { - if (dnsQuestion.ids.du == nullptr) { - return std::string(); + if (dnsQuestion.ids.du) { + return dnsQuestion.ids.du->getHTTPScheme(); + } + if (dnsQuestion.ids.doh3u) { + return dnsQuestion.ids.doh3u->getHTTPScheme(); } - return dnsQuestion.ids.du->getHTTPScheme(); + return std::string(); }); luaCtx.registerFunction (DNSQuestion::*)(void) const>("getHTTPHeaders", [](const DNSQuestion& dnsQuestion) { - if (dnsQuestion.ids.du == nullptr) { - return LuaAssociativeTable(); + if (dnsQuestion.ids.du) { + return dnsQuestion.ids.du->getHTTPHeaders(); + } + if (dnsQuestion.ids.doh3u) { + return dnsQuestion.ids.doh3u->getHTTPHeaders(); } - return dnsQuestion.ids.du->getHTTPHeaders(); + return LuaAssociativeTable(); }); luaCtx.registerFunction contentType)>("setHTTPResponse", [](DNSQuestion& dnsQuestion, uint64_t statusCode, const std::string& body, const boost::optional& contentType) { diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 03ffb98a26..d531cb95dd 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -294,12 +294,16 @@ size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq, const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpPath) { - if (dq->dq->ids.du == nullptr) { - return nullptr; - } -#ifdef HAVE_DNS_OVER_HTTPS - dq->httpPath = dq->dq->ids.du->getHTTPPath(); + if (dq->dq->ids.du) { +#if defined(HAVE_DNS_OVER_HTTPS) + dq->httpPath = dq->dq->ids.du->getHTTPPath(); #endif /* HAVE_DNS_OVER_HTTPS */ + } + else if (dq->dq->ids.doh3u) { +#if defined(HAVE_DNS_OVER_HTTP3) + dq->httpPath = dq->dq->ids.doh3u->getHTTPPath(); +#endif /* HAVE_DNS_OVER_HTTP3 */ + } } if (dq->httpPath) { return dq->httpPath->c_str(); @@ -310,12 +314,16 @@ const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpQueryString) { - if (dq->dq->ids.du == nullptr) { - return nullptr; - } + if (dq->dq->ids.du) { #ifdef HAVE_DNS_OVER_HTTPS - dq->httpQueryString = dq->dq->ids.du->getHTTPQueryString(); + dq->httpQueryString = dq->dq->ids.du->getHTTPQueryString(); #endif /* HAVE_DNS_OVER_HTTPS */ + } + else if (dq->dq->ids.doh3u) { +#if defined(HAVE_DNS_OVER_HTTP3) + dq->httpQueryString = dq->dq->ids.doh3u->getHTTPQueryString(); +#endif /* HAVE_DNS_OVER_HTTP3 */ + } } if (dq->httpQueryString) { return dq->httpQueryString->c_str(); @@ -326,12 +334,16 @@ const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestio const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpHost) { - if (dq->dq->ids.du == nullptr) { - return nullptr; - } + if (dq->dq->ids.du) { #ifdef HAVE_DNS_OVER_HTTPS - dq->httpHost = dq->dq->ids.du->getHTTPHost(); + dq->httpHost = dq->dq->ids.du->getHTTPHost(); #endif /* HAVE_DNS_OVER_HTTPS */ + } + else if (dq->dq->ids.doh3u) { +#if defined(HAVE_DNS_OVER_HTTP3) + dq->httpHost = dq->dq->ids.doh3u->getHTTPHost(); +#endif /* HAVE_DNS_OVER_HTTP3 */ + } } if (dq->httpHost) { return dq->httpHost->c_str(); @@ -342,12 +354,16 @@ const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpScheme) { - if (dq->dq->ids.du == nullptr) { - return nullptr; - } + if (dq->dq->ids.du) { #ifdef HAVE_DNS_OVER_HTTPS - dq->httpScheme = dq->dq->ids.du->getHTTPScheme(); + dq->httpScheme = dq->dq->ids.du->getHTTPScheme(); #endif /* HAVE_DNS_OVER_HTTPS */ + } + else if (dq->dq->ids.doh3u) { +#if defined(HAVE_DNS_OVER_HTTP3) + dq->httpScheme = dq->dq->ids.doh3u->getHTTPScheme(); +#endif /* HAVE_DNS_OVER_HTTP3 */ + } } if (dq->httpScheme) { return dq->httpScheme->c_str(); @@ -404,36 +420,45 @@ size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* dq, c size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_http_header_t** out) { - if (dq->dq->ids.du == nullptr) { - return 0; - } +#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3) + const auto processHeaders = [&dq](const std::unordered_map& headers) { + if (headers.size() == 0) { + return; + } + dq->httpHeaders = std::make_unique>(std::move(headers)); + if (!dq->httpHeadersVect) { + dq->httpHeadersVect = std::make_unique>(); + } + dq->httpHeadersVect->clear(); + dq->httpHeadersVect->resize(dq->httpHeaders->size()); + size_t pos = 0; + for (const auto& header : *dq->httpHeaders) { + dq->httpHeadersVect->at(pos).name = header.first.c_str(); + dq->httpHeadersVect->at(pos).value = header.second.c_str(); + ++pos; + } + }; -#ifdef HAVE_DNS_OVER_HTTPS - auto headers = dq->dq->ids.du->getHTTPHeaders(); - if (headers.size() == 0) { - return 0; +#if defined(HAVE_DNS_OVER_HTTPS) + if (dq->dq->ids.du) { + const auto& headers = dq->dq->ids.du->getHTTPHeaders(); + processHeaders(headers); } - dq->httpHeaders = std::make_unique>(std::move(headers)); - if (!dq->httpHeadersVect) { - dq->httpHeadersVect = std::make_unique>(); - } - dq->httpHeadersVect->clear(); - dq->httpHeadersVect->resize(dq->httpHeaders->size()); - size_t pos = 0; - for (const auto& header : *dq->httpHeaders) { - dq->httpHeadersVect->at(pos).name = header.first.c_str(); - dq->httpHeadersVect->at(pos).value = header.second.c_str(); - ++pos; +#endif /* HAVE_DNS_OVER_HTTPS */ +#if defined(HAVE_DNS_OVER_HTTP3) + if (dq->dq->ids.doh3u) { + const auto& headers = dq->dq->ids.doh3u->getHTTPHeaders(); + processHeaders(headers); } +#endif /* HAVE_DNS_OVER_HTTP3 */ if (!dq->httpHeadersVect->empty()) { *out = dq->httpHeadersVect->data(); } - return dq->httpHeadersVect->size(); -#else +#else /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */ return 0; -#endif +#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */ } size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out) diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index 661e9c6182..e988805073 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -49,8 +49,6 @@ using namespace dnsdist::doq; -using h3_headers_t = std::map; - class H3Connection { public: @@ -70,7 +68,7 @@ public: QuicheConfig d_config; QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free}; // buffer request headers by streamID - std::unordered_map d_headersBuffers; + std::unordered_map d_headersBuffers; std::unordered_map d_streamBuffers; std::unordered_map d_streamOutBuffers; }; @@ -629,7 +627,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit) } } -static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID) +static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers) { try { auto unit = std::make_unique(std::move(query)); @@ -639,6 +637,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con unit->ids.protocol = dnsdist::Protocol::DoH3; unit->serverConnID = serverConnID; unit->streamID = streamID; + unit->headers = std::move(headers); processDOH3Query(std::move(unit)); } @@ -706,7 +705,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API std::string_view content(reinterpret_cast(value), value_len); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API - auto* headersptr = reinterpret_cast(argp); + auto* headersptr = reinterpret_cast(argp); headersptr->emplace(key, content); return 0; }, @@ -739,7 +738,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten return; } DEBUGLOG("Dispatching GET query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID); + doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers)); conn.d_streamBuffers.erase(streamID); conn.d_headersBuffers.erase(streamID); return; @@ -804,7 +803,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, } DEBUGLOG("Dispatching POST query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID); + doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers)); conn.d_headersBuffers.erase(streamID); conn.d_streamBuffers.erase(streamID); } @@ -821,7 +820,7 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3 if (streamID < 0) { break; } - conn.d_headersBuffers.try_emplace(streamID, h3_headers_t{}); + conn.d_headersBuffers.try_emplace(streamID, dnsdist::doh3::h3_headers_t{}); switch (quiche_h3_event_type(event)) { case QUICHE_H3_EVENT_HEADERS: { @@ -1035,4 +1034,76 @@ void doh3Thread(ClientState* clientState) } } +std::string DOH3Unit::getHTTPPath() const +{ + const auto& path = headers.at(":path"); + auto pos = path.find('?'); + if (pos == string::npos) { + return path; + } + return path.substr(0, pos); +} + +std::string DOH3Unit::getHTTPQueryString() const +{ + const auto& path = headers.at(":path"); + auto pos = path.find('?'); + if (pos == string::npos) { + return std::string(); + } + + return path.substr(pos); +} + +std::string DOH3Unit::getHTTPHost() const +{ + const auto& host = headers.find(":authority"); + if (host == headers.end()) { + return {}; + } + return host->second; +} + +std::string DOH3Unit::getHTTPScheme() const +{ + const auto& scheme = headers.find(":scheme"); + if (scheme == headers.end()) { + return {}; + } + return scheme->second; +} + +const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const +{ + return headers; +} + +#else /* HAVE_DNS_OVER_HTTP3 */ + +std::string DOH3Unit::getHTTPPath() const +{ + return std::string(); +} + +std::string DOH3Unit::getHTTPQueryString() const +{ + return std::string(); +} + +std::string DOH3Unit::getHTTPHost() const +{ + return std::string(); +} + +std::string DOH3Unit::getHTTPScheme() const +{ + return std::string(); +} + +const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const +{ + static const dnsdist::doh3::h3_headers_t headers; + return headers; +} + #endif /* HAVE_DNS_OVER_HTTP3 */ diff --git a/pdns/dnsdistdist/doh3.hh b/pdns/dnsdistdist/doh3.hh index 954ea4aab2..0288ad1445 100644 --- a/pdns/dnsdistdist/doh3.hh +++ b/pdns/dnsdistdist/doh3.hh @@ -22,6 +22,7 @@ #pragma once #include +#include #include "config.h" #include "channel.hh" @@ -34,6 +35,11 @@ struct DOH3ServerConfig; struct DownstreamState; +namespace dnsdist::doh3 +{ +using h3_headers_t = std::unordered_map; +} + #ifdef HAVE_DNS_OVER_HTTP3 #include "doq-common.hh" @@ -78,10 +84,17 @@ struct DOH3Unit DOH3Unit(const DOH3Unit&) = delete; DOH3Unit& operator=(const DOH3Unit&) = delete; + [[nodiscard]] std::string getHTTPPath() const; + [[nodiscard]] std::string getHTTPQueryString() const; + [[nodiscard]] std::string getHTTPHost() const; + [[nodiscard]] std::string getHTTPScheme() const; + [[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const; + InternalQueryState ids; PacketBuffer query; PacketBuffer response; PacketBuffer serverConnID; + dnsdist::doh3::h3_headers_t headers; std::shared_ptr downstream{nullptr}; DOH3ServerConfig* dsc{nullptr}; uint64_t streamID{0}; @@ -104,6 +117,11 @@ void doh3Thread(ClientState* clientState); struct DOH3Unit { + std::string getHTTPPath() const; + std::string getHTTPQueryString() const; + const std::string& getHTTPHost() const; + const std::string& getHTTPScheme() const; + const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const; }; struct DOH3Frontend diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index bcf6ea10aa..1bf6a66bb8 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -55,6 +55,10 @@ void handleResponseSent(const InternalQueryState& ids, double udiff, const Combo { } +void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol, bool fromBackend) +{ +} + std::function& selectedBackend)> s_processQuery; ProcessQueryResult processQuery(DNSQuestion& dnsQuestion, std::shared_ptr& selectedBackend) diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 994b75168e..3515e75d04 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -1151,7 +1151,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): return (receivedQuery, message) @classmethod - def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False): + def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None): if response: if toQueue: @@ -1159,7 +1159,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): else: cls._toResponderQueue.put(response, True, timeout) - message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post) + message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders) receivedQuery = None diff --git a/regression-tests.dnsdist/doh3client.py b/regression-tests.dnsdist/doh3client.py index 85d66a399a..c1a1ae784d 100644 --- a/regression-tests.dnsdist/doh3client.py +++ b/regression-tests.dnsdist/doh3client.py @@ -1,4 +1,5 @@ import base64 +import copy import asyncio import pickle import ssl @@ -133,7 +134,7 @@ class HttpClient(QuicConnectionProtocol): (b":authority", request.url.authority.encode()), (b":path", request.url.full_path.encode()), ] - + [(k.encode(), v.encode()) for (k, v) in request.headers.items()], + + [(k.lower().encode(), v.encode()) for (k, v) in request.headers.items()], end_stream=not request.content, ) if request.content: @@ -155,21 +156,22 @@ async def perform_http_request( data: Optional[bytes], include: bool, output_dir: Optional[str], + additional_headers: Optional[Dict] = None, ) -> None: # perform request start = time.time() if data is not None: + headers = copy.deepcopy(additional_headers) if additional_headers else {} + headers["content-length"] = str(len(data)) + headers["content-type"] = "application/dns-message" http_events = await client.post( url, data=data, - headers={ - "content-length": str(len(data)), - "content-type": "application/dns-message", - }, + headers=headers, ) method = "POST" else: - http_events = await client.get(url) + http_events = await client.get(url, headers=additional_headers) method = "GET" elapsed = time.time() - start @@ -190,6 +192,7 @@ async def async_h3_query( timeout: float, post: bool, create_protocol=HttpClient, + additional_headers: Optional[Dict] = None, ) -> None: url = baseurl @@ -212,6 +215,7 @@ async def async_h3_query( data=query.to_wire() if post else None, include=False, output_dir=None, + additional_headers=additional_headers, ) return answer @@ -219,7 +223,7 @@ async def async_h3_query( return e -def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False): +def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None): configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True) if verify: configuration.load_verify_locations(verify) @@ -232,7 +236,8 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname query=query, timeout=timeout, create_protocol=HttpClient, - post=post + post=post, + additional_headers=additional_headers ) ) diff --git a/regression-tests.dnsdist/test_DOH3.py b/regression-tests.dnsdist/test_DOH3.py index 9634c914c7..d1a63552ff 100644 --- a/regression-tests.dnsdist/test_DOH3.py +++ b/regression-tests.dnsdist/test_DOH3.py @@ -20,11 +20,31 @@ class TestDOH3(QUICTests, DNSDistTest): addAction("drop.doq.tests.powerdns.com.", DropAction()) addAction("refused.doq.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED)) addAction("spoof.doq.tests.powerdns.com.", SpoofAction("1.2.3.4")) + addAction(HTTPHeaderRule("X-PowerDNS", "^[a]{5}$"), SpoofAction("2.3.4.5")) + addAction(HTTPPathRule("/PowerDNS"), SpoofAction("3.4.5.6")) + addAction(HTTPPathRegexRule("^/PowerDNS-[0-9]"), SpoofAction("6.7.8.9")) addAction("no-backend.doq.tests.powerdns.com.", PoolAction('this-pool-has-no-backend')) + function dohHandler(dq) + if dq:getHTTPScheme() == 'https' and dq:getHTTPHost() == '%s:%d' and dq:getHTTPPath() == '/' and dq:getHTTPQueryString() == '' then + local foundct = false + for key,value in pairs(dq:getHTTPHeaders()) do + if key == 'content-type' and value == 'application/dns-message' then + foundct = true + break + end + end + if foundct then + return DNSAction.Spoof, "10.11.12.13" + end + end + return DNSAction.None + end + addAction("http-lua.doh3.tests.powerdns.com.", LuaAction(dohHandler)) + addDOH3Local("127.0.0.1:%d", "%s", "%s", {keyLogFile='/tmp/keys'}) """ - _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] + _config_params = ['_testServerPort', '_serverName', '_doqServerPort', '_doqServerPort','_serverCert', '_serverKey'] _verboseMode = True def getQUICConnection(self): @@ -33,6 +53,137 @@ class TestDOH3(QUICTests, DNSDistTest): def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection) + def testHeaderRule(self): + """ + DOH3: HeaderRule + """ + name = 'header-rule.doh3.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 0 + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '2.3.4.5') + expectedResponse.answer.append(rrset) + + # this header should match + (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query=query, response=None, useQueue=False, caFile=self._caCert, customHeaders={'x-powerdnS': 'aaaaa'}) + self.assertEqual(receivedResponse, expectedResponse) + + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.flags &= ~dns.flags.RD + expectedQuery.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + # this content of the header should NOT match + (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, customHeaders={'x-powerdnS': 'bbbbb'}) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryNoEDNS(expectedQuery, receivedQuery) + self.assertEqual(response, receivedResponse) + + def testHTTPPath(self): + """ + DOH3: HTTPPath + """ + name = 'http-path.doh3.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 0 + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '3.4.5.6') + expectedResponse.answer.append(rrset) + + # this path should match + (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + 'PowerDNS', caFile=self._caCert, query=query, response=None, useQueue=False) + self.assertEqual(receivedResponse, expectedResponse) + + expectedQuery = dns.message.make_query(name, 'A', 'IN') + expectedQuery.id = 0 + expectedQuery.flags &= ~dns.flags.RD + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + # this path should NOT match + (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + "PowerDNS2", query, response=response, caFile=self._caCert) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryNoEDNS(expectedQuery, receivedQuery) + self.assertEqual(response, receivedResponse) + + def testHTTPPathRegex(self): + """ + DOH3: HTTPPathRegex + """ + name = 'http-path-regex.doh3.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 0 + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '6.7.8.9') + expectedResponse.answer.append(rrset) + + # this path should match + (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + 'PowerDNS-999', caFile=self._caCert, query=query, response=None, useQueue=False) + self.assertEqual(receivedResponse, expectedResponse) + + expectedQuery = dns.message.make_query(name, 'A', 'IN') + expectedQuery.id = 0 + expectedQuery.flags &= ~dns.flags.RD + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + # this path should NOT match + (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + "PowerDNS2", query, response=response, caFile=self._caCert) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryNoEDNS(expectedQuery, receivedQuery) + self.assertEqual(response, receivedResponse) + + def testHTTPLuaBindings(self): + """ + DOH3: Lua HTTP bindings + """ + name = 'http-lua.doh3.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + + (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, post=True) + self.assertTrue(receivedResponse) + class TestDOH3ACL(QUICACLTests, DNSDistTest): _serverKey = 'server.key' _serverCert = 'server.chain'