]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add DoH3 headers, query string, path and scheme bindings
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 30 Dec 2024 14:51:01 +0000 (15:51 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 30 Dec 2024 15:01:58 +0000 (16:01 +0100)
The DoH ones have been there for a long time, but the DoH3 ones were
missing. Note that we still don't have the ability to set a HTTP
response for DoH3 queries (including response maps) and SNI is still
missing (Quiche does not make that last one easy).

pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/doh3.cc
pdns/dnsdistdist/doh3.hh
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/doh3client.py
regression-tests.dnsdist/test_DOH3.py

index 4c9f650ce36f3af4e1f1f2b51f34291c2dc99eaa..2e445dff4082c93ac033ef40e729153d1dbc1930 100644 (file)
@@ -480,12 +480,15 @@ endif
 
 if HAVE_DNS_OVER_HTTP3
 dnsdist_SOURCES += doh3.cc
+testrunner_SOURCES += doh3.cc
 endif
 
 if HAVE_QUICHE
 AM_CPPFLAGS += $(QUICHE_CFLAGS)
 dnsdist_LDADD += $(QUICHE_LDFLAGS) $(QUICHE_LIBS)
 dnsdist_SOURCES += doq-common.cc
+testrunner_SOURCES += doq-common.cc
+testrunner_LDADD += $(QUICHE_LDFLAGS) $(QUICHE_LIBS)
 endif
 
 if !HAVE_LUA_HPP
index 4ff9ff2900e3c47b46d1dfe484244442a0c7555c..d6f48b9390d5e1bb7644fdf33cfcbb4c37b0aa4c 100644 (file)
@@ -32,15 +32,23 @@ HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& reg
 
 bool HTTPHeaderRule::matches(const DNSQuestion* dq) const
 {
-  if (!dq->ids.du) {
+  if (dq->ids.du) {
+    const auto& headers = dq->ids.du->getHTTPHeaders();
+    for (const auto& header : headers) {
+      if (header.first == d_header) {
+        return d_regex.match(header.second);
+      }
+    }
     return false;
   }
-
-  const auto& headers = dq->ids.du->getHTTPHeaders();
-  for (const auto& header : headers) {
-    if (header.first == d_header) {
-      return d_regex.match(header.second);
+  if (dq->ids.doh3u) {
+    const auto& headers = dq->ids.doh3u->getHTTPHeaders();
+    for (const auto& header : headers) {
+      if (header.first == d_header) {
+        return d_regex.match(header.second);
+      }
     }
+    return false;
   }
   return false;
 }
@@ -57,12 +65,14 @@ HTTPPathRule::HTTPPathRule(std::string path) :
 
 bool HTTPPathRule::matches(const DNSQuestion* dq) const
 {
-  if (!dq->ids.du) {
-    return false;
+  if (dq->ids.du) {
+    const auto path = dq->ids.du->getHTTPPath();
+    return d_path == path;
   }
-
-  const auto path = dq->ids.du->getHTTPPath();
-  return d_path == path;
+  else if (dq->ids.doh3u) {
+    return dq->ids.doh3u->getHTTPPath() == d_path;
+  }
+  return false;
 }
 
 string HTTPPathRule::toString() const
@@ -77,11 +87,14 @@ HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex) :
 
 bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const
 {
-  if (!dq->ids.du) {
-    return false;
+  if (dq->ids.du) {
+    const auto path = dq->ids.du->getHTTPPath();
+    return d_regex.match(path);
   }
-
-  return d_regex.match(dq->ids.du->getHTTPPath());
+  else if (dq->ids.doh3u) {
+    return d_regex.match(dq->ids.doh3u->getHTTPPath());
+  }
+  return false;
 }
 
 string HTTPPathRegexRule::toString() const
index 3527ce7ba2800966fbd05b4b7473ab04693d0d53..58546e1029f4d02f9db904cdc9267ea9e61a62ff 100644 (file)
@@ -513,38 +513,53 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
 
 #ifdef HAVE_DNS_OVER_HTTPS
   luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPPath", [](const DNSQuestion& dnsQuestion) {
-    if (dnsQuestion.ids.du == nullptr) {
-      return std::string();
+    if (dnsQuestion.ids.du) {
+      return dnsQuestion.ids.du->getHTTPPath();
+    }
+    if (dnsQuestion.ids.doh3u) {
+      return dnsQuestion.ids.doh3u->getHTTPPath();
     }
-    return dnsQuestion.ids.du->getHTTPPath();
+    return std::string();
   });
 
   luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPQueryString", [](const DNSQuestion& dnsQuestion) {
-    if (dnsQuestion.ids.du == nullptr) {
-      return std::string();
+    if (dnsQuestion.ids.du) {
+      return dnsQuestion.ids.du->getHTTPQueryString();
     }
-    return dnsQuestion.ids.du->getHTTPQueryString();
+    if (dnsQuestion.ids.doh3u) {
+      return dnsQuestion.ids.doh3u->getHTTPQueryString();
+    }
+    return std::string();
   });
 
   luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPHost", [](const DNSQuestion& dnsQuestion) {
-    if (dnsQuestion.ids.du == nullptr) {
-      return std::string();
+    if (dnsQuestion.ids.du) {
+      return dnsQuestion.ids.du->getHTTPHost();
     }
-    return dnsQuestion.ids.du->getHTTPHost();
+    if (dnsQuestion.ids.doh3u) {
+      return dnsQuestion.ids.doh3u->getHTTPHost();
+    }
+    return std::string();
   });
 
   luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPScheme", [](const DNSQuestion& dnsQuestion) {
-    if (dnsQuestion.ids.du == nullptr) {
-      return std::string();
+    if (dnsQuestion.ids.du) {
+      return dnsQuestion.ids.du->getHTTPScheme();
+    }
+    if (dnsQuestion.ids.doh3u) {
+      return dnsQuestion.ids.doh3u->getHTTPScheme();
     }
-    return dnsQuestion.ids.du->getHTTPScheme();
+    return std::string();
   });
 
   luaCtx.registerFunction<LuaAssociativeTable<std::string> (DNSQuestion::*)(void) const>("getHTTPHeaders", [](const DNSQuestion& dnsQuestion) {
-    if (dnsQuestion.ids.du == nullptr) {
-      return LuaAssociativeTable<std::string>();
+    if (dnsQuestion.ids.du) {
+      return dnsQuestion.ids.du->getHTTPHeaders();
+    }
+    if (dnsQuestion.ids.doh3u) {
+      return dnsQuestion.ids.doh3u->getHTTPHeaders();
     }
-    return dnsQuestion.ids.du->getHTTPHeaders();
+    return LuaAssociativeTable<std::string>();
   });
 
   luaCtx.registerFunction<void (DNSQuestion::*)(uint64_t statusCode, const std::string& body, const boost::optional<std::string> contentType)>("setHTTPResponse", [](DNSQuestion& dnsQuestion, uint64_t statusCode, const std::string& body, const boost::optional<std::string>& contentType) {
index 03ffb98a260226a8ec79248ade20bab936445a32..d531cb95dd5610b3e96ec0d233923156954f6ee5 100644 (file)
@@ -294,12 +294,16 @@ size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq,
 const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq)
 {
   if (!dq->httpPath) {
-    if (dq->dq->ids.du == nullptr) {
-      return nullptr;
-    }
-#ifdef HAVE_DNS_OVER_HTTPS
-    dq->httpPath = dq->dq->ids.du->getHTTPPath();
+    if (dq->dq->ids.du) {
+#if defined(HAVE_DNS_OVER_HTTPS)
+      dq->httpPath = dq->dq->ids.du->getHTTPPath();
 #endif /* HAVE_DNS_OVER_HTTPS */
+    }
+    else if (dq->dq->ids.doh3u) {
+#if defined(HAVE_DNS_OVER_HTTP3)
+      dq->httpPath = dq->dq->ids.doh3u->getHTTPPath();
+#endif /* HAVE_DNS_OVER_HTTP3 */
+    }
   }
   if (dq->httpPath) {
     return dq->httpPath->c_str();
@@ -310,12 +314,16 @@ const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq)
 const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq)
 {
   if (!dq->httpQueryString) {
-    if (dq->dq->ids.du == nullptr) {
-      return nullptr;
-    }
+    if (dq->dq->ids.du) {
 #ifdef HAVE_DNS_OVER_HTTPS
-    dq->httpQueryString = dq->dq->ids.du->getHTTPQueryString();
+      dq->httpQueryString = dq->dq->ids.du->getHTTPQueryString();
 #endif /* HAVE_DNS_OVER_HTTPS */
+    }
+    else if (dq->dq->ids.doh3u) {
+#if defined(HAVE_DNS_OVER_HTTP3)
+      dq->httpQueryString = dq->dq->ids.doh3u->getHTTPQueryString();
+#endif /* HAVE_DNS_OVER_HTTP3 */
+    }
   }
   if (dq->httpQueryString) {
     return dq->httpQueryString->c_str();
@@ -326,12 +334,16 @@ const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestio
 const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq)
 {
   if (!dq->httpHost) {
-    if (dq->dq->ids.du == nullptr) {
-      return nullptr;
-    }
+    if (dq->dq->ids.du) {
 #ifdef HAVE_DNS_OVER_HTTPS
-    dq->httpHost = dq->dq->ids.du->getHTTPHost();
+      dq->httpHost = dq->dq->ids.du->getHTTPHost();
 #endif /* HAVE_DNS_OVER_HTTPS */
+    }
+    else if (dq->dq->ids.doh3u) {
+#if defined(HAVE_DNS_OVER_HTTP3)
+      dq->httpHost = dq->dq->ids.doh3u->getHTTPHost();
+#endif /* HAVE_DNS_OVER_HTTP3 */
+    }
   }
   if (dq->httpHost) {
     return dq->httpHost->c_str();
@@ -342,12 +354,16 @@ const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq)
 const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq)
 {
   if (!dq->httpScheme) {
-    if (dq->dq->ids.du == nullptr) {
-      return nullptr;
-    }
+    if (dq->dq->ids.du) {
 #ifdef HAVE_DNS_OVER_HTTPS
-    dq->httpScheme = dq->dq->ids.du->getHTTPScheme();
+      dq->httpScheme = dq->dq->ids.du->getHTTPScheme();
 #endif /* HAVE_DNS_OVER_HTTPS */
+    }
+    else if (dq->dq->ids.doh3u) {
+#if defined(HAVE_DNS_OVER_HTTP3)
+      dq->httpScheme = dq->dq->ids.doh3u->getHTTPScheme();
+#endif /* HAVE_DNS_OVER_HTTP3 */
+    }
   }
   if (dq->httpScheme) {
     return dq->httpScheme->c_str();
@@ -404,36 +420,45 @@ size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* dq, c
 
 size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_http_header_t** out)
 {
-  if (dq->dq->ids.du == nullptr) {
-    return 0;
-  }
+#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
+  const auto processHeaders = [&dq](const std::unordered_map<std::string, std::string>& headers) {
+    if (headers.size() == 0) {
+      return;
+    }
+    dq->httpHeaders = std::make_unique<std::unordered_map<std::string, std::string>>(std::move(headers));
+    if (!dq->httpHeadersVect) {
+      dq->httpHeadersVect = std::make_unique<std::vector<dnsdist_ffi_http_header_t>>();
+    }
+    dq->httpHeadersVect->clear();
+    dq->httpHeadersVect->resize(dq->httpHeaders->size());
+    size_t pos = 0;
+    for (const auto& header : *dq->httpHeaders) {
+      dq->httpHeadersVect->at(pos).name = header.first.c_str();
+      dq->httpHeadersVect->at(pos).value = header.second.c_str();
+      ++pos;
+    }
+  };
 
-#ifdef HAVE_DNS_OVER_HTTPS
-  auto headers = dq->dq->ids.du->getHTTPHeaders();
-  if (headers.size() == 0) {
-    return 0;
+#if defined(HAVE_DNS_OVER_HTTPS)
+  if (dq->dq->ids.du) {
+    const auto& headers = dq->dq->ids.du->getHTTPHeaders();
+    processHeaders(headers);
   }
-  dq->httpHeaders = std::make_unique<std::unordered_map<std::string, std::string>>(std::move(headers));
-  if (!dq->httpHeadersVect) {
-    dq->httpHeadersVect = std::make_unique<std::vector<dnsdist_ffi_http_header_t>>();
-  }
-  dq->httpHeadersVect->clear();
-  dq->httpHeadersVect->resize(dq->httpHeaders->size());
-  size_t pos = 0;
-  for (const auto& header : *dq->httpHeaders) {
-    dq->httpHeadersVect->at(pos).name = header.first.c_str();
-    dq->httpHeadersVect->at(pos).value = header.second.c_str();
-    ++pos;
+#endif /* HAVE_DNS_OVER_HTTPS */
+#if defined(HAVE_DNS_OVER_HTTP3)
+  if (dq->dq->ids.doh3u) {
+    const auto& headers = dq->dq->ids.doh3u->getHTTPHeaders();
+    processHeaders(headers);
   }
+#endif /* HAVE_DNS_OVER_HTTP3 */
 
   if (!dq->httpHeadersVect->empty()) {
     *out = dq->httpHeadersVect->data();
   }
-
   return dq->httpHeadersVect->size();
-#else
+#else /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */
   return 0;
-#endif
+#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */
 }
 
 size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out)
index 661e9c61822366d885c376138cfdf3025a846000..e98880507398e239d391b4b822171789a4196a52 100644 (file)
@@ -49,8 +49,6 @@
 
 using namespace dnsdist::doq;
 
-using h3_headers_t = std::map<std::string, std::string>;
-
 class H3Connection
 {
 public:
@@ -70,7 +68,7 @@ public:
   QuicheConfig d_config;
   QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
   // buffer request headers by streamID
-  std::unordered_map<uint64_t, h3_headers_t> d_headersBuffers;
+  std::unordered_map<uint64_t, dnsdist::doh3::h3_headers_t> d_headersBuffers;
   std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
   std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
 };
@@ -629,7 +627,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
   }
 }
 
-static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
+static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers)
 {
   try {
     auto unit = std::make_unique<DOH3Unit>(std::move(query));
@@ -639,6 +637,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con
     unit->ids.protocol = dnsdist::Protocol::DoH3;
     unit->serverConnID = serverConnID;
     unit->streamID = streamID;
+    unit->headers = std::move(headers);
 
     processDOH3Query(std::move(unit));
   }
@@ -706,7 +705,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
       std::string_view content(reinterpret_cast<char*>(value), value_len);
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
-      auto* headersptr = reinterpret_cast<h3_headers_t*>(argp);
+      auto* headersptr = reinterpret_cast<dnsdist::doh3::h3_headers_t*>(argp);
       headersptr->emplace(key, content);
       return 0;
     },
@@ -739,7 +738,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
       return;
     }
     DEBUGLOG("Dispatching GET query");
-    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID);
+    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
     conn.d_streamBuffers.erase(streamID);
     conn.d_headersBuffers.erase(streamID);
     return;
@@ -804,7 +803,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
   }
 
   DEBUGLOG("Dispatching POST query");
-  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
+  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
   conn.d_headersBuffers.erase(streamID);
   conn.d_streamBuffers.erase(streamID);
 }
@@ -821,7 +820,7 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3
     if (streamID < 0) {
       break;
     }
-    conn.d_headersBuffers.try_emplace(streamID, h3_headers_t{});
+    conn.d_headersBuffers.try_emplace(streamID, dnsdist::doh3::h3_headers_t{});
 
     switch (quiche_h3_event_type(event)) {
     case QUICHE_H3_EVENT_HEADERS: {
@@ -1035,4 +1034,76 @@ void doh3Thread(ClientState* clientState)
   }
 }
 
+std::string DOH3Unit::getHTTPPath() const
+{
+  const auto& path = headers.at(":path");
+  auto pos = path.find('?');
+  if (pos == string::npos) {
+    return path;
+  }
+  return path.substr(0, pos);
+}
+
+std::string DOH3Unit::getHTTPQueryString() const
+{
+  const auto& path = headers.at(":path");
+  auto pos = path.find('?');
+  if (pos == string::npos) {
+    return std::string();
+  }
+
+  return path.substr(pos);
+}
+
+std::string DOH3Unit::getHTTPHost() const
+{
+  const auto& host = headers.find(":authority");
+  if (host == headers.end()) {
+    return {};
+  }
+  return host->second;
+}
+
+std::string DOH3Unit::getHTTPScheme() const
+{
+  const auto& scheme = headers.find(":scheme");
+  if (scheme == headers.end()) {
+    return {};
+  }
+  return scheme->second;
+}
+
+const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const
+{
+  return headers;
+}
+
+#else /* HAVE_DNS_OVER_HTTP3 */
+
+std::string DOH3Unit::getHTTPPath() const
+{
+  return std::string();
+}
+
+std::string DOH3Unit::getHTTPQueryString() const
+{
+  return std::string();
+}
+
+std::string DOH3Unit::getHTTPHost() const
+{
+  return std::string();
+}
+
+std::string DOH3Unit::getHTTPScheme() const
+{
+  return std::string();
+}
+
+const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const
+{
+  static const dnsdist::doh3::h3_headers_t headers;
+  return headers;
+}
+
 #endif /* HAVE_DNS_OVER_HTTP3 */
index 954ea4aab2ea3345670cfb2a3001258e98fd69cb..0288ad14452873f6261bc84263de14dd7be7a014 100644 (file)
@@ -22,6 +22,7 @@
 #pragma once
 
 #include <memory>
+#include <string>
 
 #include "config.h"
 #include "channel.hh"
 struct DOH3ServerConfig;
 struct DownstreamState;
 
+namespace dnsdist::doh3
+{
+using h3_headers_t = std::unordered_map<std::string, std::string>;
+}
+
 #ifdef HAVE_DNS_OVER_HTTP3
 
 #include "doq-common.hh"
@@ -78,10 +84,17 @@ struct DOH3Unit
   DOH3Unit(const DOH3Unit&) = delete;
   DOH3Unit& operator=(const DOH3Unit&) = delete;
 
+  [[nodiscard]] std::string getHTTPPath() const;
+  [[nodiscard]] std::string getHTTPQueryString() const;
+  [[nodiscard]] std::string getHTTPHost() const;
+  [[nodiscard]] std::string getHTTPScheme() const;
+  [[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
+
   InternalQueryState ids;
   PacketBuffer query;
   PacketBuffer response;
   PacketBuffer serverConnID;
+  dnsdist::doh3::h3_headers_t headers;
   std::shared_ptr<DownstreamState> downstream{nullptr};
   DOH3ServerConfig* dsc{nullptr};
   uint64_t streamID{0};
@@ -104,6 +117,11 @@ void doh3Thread(ClientState* clientState);
 
 struct DOH3Unit
 {
+  std::string getHTTPPath() const;
+  std::string getHTTPQueryString() const;
+  const std::string& getHTTPHost() const;
+  const std::string& getHTTPScheme() const;
+  const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
 };
 
 struct DOH3Frontend
index bcf6ea10aa21a7839b4a9f286f5fd1b8701f6805..1bf6a66bb88ca0dd26fcdc8a1c04342e4ab2493c 100644 (file)
@@ -55,6 +55,10 @@ void handleResponseSent(const InternalQueryState& ids, double udiff, const Combo
 {
 }
 
+void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol, bool fromBackend)
+{
+}
+
 std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
 
 ProcessQueryResult processQuery(DNSQuestion& dnsQuestion, std::shared_ptr<DownstreamState>& selectedBackend)
index 994b75168ec0a4d583eb9232dd5097cfde08ebe9..3515e75d0426fc3bef3f5d21fd77358006da9893 100644 (file)
@@ -1151,7 +1151,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return (receivedQuery, message)
 
     @classmethod
-    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
+    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None):
 
         if response:
             if toQueue:
@@ -1159,7 +1159,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             else:
                 cls._toResponderQueue.put(response, True, timeout)
 
-        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)
+        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders)
 
         receivedQuery = None
 
index 85d66a399a51ce290161643783d820c678e0b7d7..c1a1ae784ddba7027ad1f0d198b9b4dc2e9ac58c 100644 (file)
@@ -1,4 +1,5 @@
 import base64
+import copy
 import asyncio
 import pickle
 import ssl
@@ -133,7 +134,7 @@ class HttpClient(QuicConnectionProtocol):
                 (b":authority", request.url.authority.encode()),
                 (b":path", request.url.full_path.encode()),
             ]
-            + [(k.encode(), v.encode()) for (k, v) in request.headers.items()],
+            + [(k.lower().encode(), v.encode()) for (k, v) in request.headers.items()],
             end_stream=not request.content,
         )
         if request.content:
@@ -155,21 +156,22 @@ async def perform_http_request(
     data: Optional[bytes],
     include: bool,
     output_dir: Optional[str],
+    additional_headers: Optional[Dict] = None,
 ) -> None:
     # perform request
     start = time.time()
     if data is not None:
+        headers = copy.deepcopy(additional_headers) if additional_headers else {}
+        headers["content-length"] = str(len(data))
+        headers["content-type"] = "application/dns-message"
         http_events = await client.post(
             url,
             data=data,
-            headers={
-                "content-length": str(len(data)),
-                "content-type": "application/dns-message",
-            },
+            headers=headers,
         )
         method = "POST"
     else:
-        http_events = await client.get(url)
+        http_events = await client.get(url, headers=additional_headers)
         method = "GET"
     elapsed = time.time() - start
 
@@ -190,6 +192,7 @@ async def async_h3_query(
     timeout: float,
     post: bool,
     create_protocol=HttpClient,
+    additional_headers: Optional[Dict] = None,
 ) -> None:
 
     url = baseurl
@@ -212,6 +215,7 @@ async def async_h3_query(
                     data=query.to_wire() if post else None,
                     include=False,
                     output_dir=None,
+                    additional_headers=additional_headers,
                 )
 
                 return answer
@@ -219,7 +223,7 @@ async def async_h3_query(
             return e
 
 
-def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False):
+def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None):
     configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
     if verify:
         configuration.load_verify_locations(verify)
@@ -232,7 +236,8 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname
             query=query,
             timeout=timeout,
             create_protocol=HttpClient,
-            post=post
+            post=post,
+            additional_headers=additional_headers
         )
     )
 
index 9634c914c7d8d2acecd7523abb15ad0aed9b3943..d1a63552ff9a76abccd8f218cd1a926e406a7526 100644 (file)
@@ -20,11 +20,31 @@ class TestDOH3(QUICTests, DNSDistTest):
     addAction("drop.doq.tests.powerdns.com.", DropAction())
     addAction("refused.doq.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED))
     addAction("spoof.doq.tests.powerdns.com.", SpoofAction("1.2.3.4"))
+    addAction(HTTPHeaderRule("X-PowerDNS", "^[a]{5}$"), SpoofAction("2.3.4.5"))
+    addAction(HTTPPathRule("/PowerDNS"), SpoofAction("3.4.5.6"))
+    addAction(HTTPPathRegexRule("^/PowerDNS-[0-9]"), SpoofAction("6.7.8.9"))
     addAction("no-backend.doq.tests.powerdns.com.", PoolAction('this-pool-has-no-backend'))
 
+    function dohHandler(dq)
+      if dq:getHTTPScheme() == 'https' and dq:getHTTPHost() == '%s:%d' and dq:getHTTPPath() == '/' and dq:getHTTPQueryString() == '' then
+        local foundct = false
+        for key,value in pairs(dq:getHTTPHeaders()) do
+          if key == 'content-type' and value == 'application/dns-message' then
+            foundct = true
+            break
+          end
+        end
+        if foundct then
+          return DNSAction.Spoof, "10.11.12.13"
+        end
+      end
+      return DNSAction.None
+    end
+    addAction("http-lua.doh3.tests.powerdns.com.", LuaAction(dohHandler))
+
     addDOH3Local("127.0.0.1:%d", "%s", "%s", {keyLogFile='/tmp/keys'})
     """
-    _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey']
+    _config_params = ['_testServerPort',  '_serverName', '_doqServerPort', '_doqServerPort','_serverCert', '_serverKey']
     _verboseMode = True
 
     def getQUICConnection(self):
@@ -33,6 +53,137 @@ class TestDOH3(QUICTests, DNSDistTest):
     def sendQUICQuery(self, query, response=None, useQueue=True, connection=None):
         return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection)
 
+    def testHeaderRule(self):
+        """
+        DOH3: HeaderRule
+        """
+        name = 'header-rule.doh3.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.id = 0
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '2.3.4.5')
+        expectedResponse.answer.append(rrset)
+
+        # this header should match
+        (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query=query, response=None, useQueue=False, caFile=self._caCert, customHeaders={'x-powerdnS': 'aaaaa'})
+        self.assertEqual(receivedResponse, expectedResponse)
+
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.flags &= ~dns.flags.RD
+        expectedQuery.id = 0
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        # this content of the header should NOT match
+        (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, customHeaders={'x-powerdnS': 'bbbbb'})
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryNoEDNS(expectedQuery, receivedQuery)
+        self.assertEqual(response, receivedResponse)
+
+    def testHTTPPath(self):
+        """
+        DOH3: HTTPPath
+        """
+        name = 'http-path.doh3.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.id = 0
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '3.4.5.6')
+        expectedResponse.answer.append(rrset)
+
+        # this path should match
+        (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + 'PowerDNS', caFile=self._caCert, query=query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, expectedResponse)
+
+        expectedQuery = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery.id = 0
+        expectedQuery.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        # this path should NOT match
+        (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + "PowerDNS2", query, response=response, caFile=self._caCert)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryNoEDNS(expectedQuery, receivedQuery)
+        self.assertEqual(response, receivedResponse)
+
+    def testHTTPPathRegex(self):
+        """
+        DOH3: HTTPPathRegex
+        """
+        name = 'http-path-regex.doh3.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.id = 0
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '6.7.8.9')
+        expectedResponse.answer.append(rrset)
+
+        # this path should match
+        (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + 'PowerDNS-999', caFile=self._caCert, query=query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, expectedResponse)
+
+        expectedQuery = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery.id = 0
+        expectedQuery.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        # this path should NOT match
+        (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL + "PowerDNS2", query, response=response, caFile=self._caCert)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryNoEDNS(expectedQuery, receivedQuery)
+        self.assertEqual(response, receivedResponse)
+
+    def testHTTPLuaBindings(self):
+        """
+        DOH3: Lua HTTP bindings
+        """
+        name = 'http-lua.doh3.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.id = 0
+
+        (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, post=True)
+        self.assertTrue(receivedResponse)
+
 class TestDOH3ACL(QUICACLTests, DNSDistTest):
     _serverKey = 'server.key'
     _serverCert = 'server.chain'