]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: doh3 add support for post queries
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Mon, 4 Dec 2023 14:58:03 +0000 (15:58 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 8 Dec 2023 08:14:41 +0000 (09:14 +0100)
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/doh3.cc
pdns/dnsdistdist/doq-common.cc
pdns/dnsdistdist/doq-common.hh
pdns/dnsdistdist/doq.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/doh3client.py
regression-tests.dnsdist/test_DOH3.py

index 8d3334d35d6a2a3f3453edd6f007ceeef3e2fe16..bc597a6611f19c8c019e3695982aa599c22993f9 100644 (file)
@@ -215,8 +215,8 @@ dnsdist_SOURCES = \
        doh.hh \
        doh3.hh \
        dolog.cc dolog.hh \
-       doq.hh \
        doq-common.hh \
+       doq.hh \
        ednscookies.cc ednscookies.hh \
        ednsextendederror.cc ednsextendederror.hh \
        ednsoptions.cc ednsoptions.hh \
index 02a2ea928ca849ea1b0c425368ac579c112f9c68..71f3adbd7a505435a0599255c5756519cc06d530 100644 (file)
 
 #include "doq-common.hh"
 
+#if 0
+#define DEBUGLOG_ENABLED
+#define DEBUGLOG(x) std::cerr << x << std::endl;
+#else
+#define DEBUGLOG(x)
+#endif
+
 using namespace dnsdist::doq;
 
 class H3Connection
@@ -626,7 +633,6 @@ void doh3Thread(ClientState* clientState)
 
     Socket sock(clientState->udpFD);
 
-    PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
     auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
 
     auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
@@ -726,6 +732,7 @@ void doh3Thread(ClientState* clientState)
             DEBUGLOG("Successfully created HTTP/3 connection");
           }
 
+          std::map<std::string, std::string> headers;
           while (true) {
             quiche_h3_event* ev;
             // Processes HTTP/3 data received from the peer
@@ -739,20 +746,17 @@ void doh3Thread(ClientState* clientState)
 
             switch (quiche_h3_event_type(ev)) {
             case QUICHE_H3_EVENT_HEADERS: {
-              std::string path;
               int rc = quiche_h3_event_for_each_header(
                 ev,
                 [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int {
                   std::string_view key(reinterpret_cast<char*>(name), name_len);
                   std::string_view content(reinterpret_cast<char*>(value), value_len);
-                  if (key == ":path") {
-                    auto pathptr = reinterpret_cast<std::string*>(argp);
-                    *pathptr = content;
-                  }
+                  auto headersptr = reinterpret_cast<std::map<std::string, std::string>*>(argp);
+                  headersptr->emplace(key, content);
                   return 0;
                 },
-                &path);
-              if (rc != 0) {
+                &headers);
+              if (rc != 0 || !headers.count(":method")) {
                 DEBUGLOG("Failed to process headers");
                 ++dnsdist::metrics::g_stats.nonCompliantQueries;
                 ++clientState->nonCompliantQueries;
@@ -760,63 +764,20 @@ void doh3Thread(ClientState* clientState)
                 h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to process query headers");
                 break;
               }
-              if (path.empty()) {
-                DEBUGLOG("Path not found");
-                ++dnsdist::metrics::g_stats.nonCompliantQueries;
-                ++clientState->nonCompliantQueries;
-                ++frontend->d_errorResponses;
-                h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found");
-                break;
-              }
-              {
-                auto pos = path.find("?dns=");
-                if (pos == string::npos) {
-                  pos = path.find("&dns=");
+              if (headers.at(":method") == "GET") {
+                if (!headers.count(":path") || headers.at(":path").empty()) {
+                  DEBUGLOG("Path not found");
+                  ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                  ++clientState->nonCompliantQueries;
+                  ++frontend->d_errorResponses;
+                  h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found");
+                  break;
                 }
-                if (pos != string::npos) {
-                  // need to base64url decode this
-                  string sdns(path.substr(pos + 5));
-                  boost::replace_all(sdns, "-", "+");
-                  boost::replace_all(sdns, "_", "/");
-                  // re-add padding that may have been missing
-                  switch (sdns.size() % 4) {
-                  case 2:
-                    sdns.append(2, '=');
-                    break;
-                  case 3:
-                    sdns.append(1, '=');
-                    break;
-                  }
-
-                  PacketBuffer decoded;
-
-                  /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */
-                  const size_t maxAdditionalSizeForEDNS = 35U;
-                  /* rough estimate so we hopefully don't need a new allocation later */
-                  /* We reserve at few additional bytes to be able to add EDNS later */
-                  const size_t estimate = ((sdns.size() * 3) / 4);
-                  decoded.reserve(estimate + maxAdditionalSizeForEDNS);
-                  if (B64Decode(sdns, decoded) < 0) {
-                    DEBUGLOG("Unable to base64 decode()");
-                    ++dnsdist::metrics::g_stats.nonCompliantQueries;
-                    ++clientState->nonCompliantQueries;
-                    ++frontend->d_errorResponses;
-                    h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL");
-                    break;
-                  }
-
-                  if (decoded.size() < sizeof(dnsheader)) {
-                    ++dnsdist::metrics::g_stats.nonCompliantQueries;
-                    ++clientState->nonCompliantQueries;
-                    ++frontend->d_errorResponses;
-                    h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query");
-                    break;
-                  }
-                  DEBUGLOG("Dispatching query");
-                  doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID);
-                  conn->get().d_streamBuffers.erase(streamID);
+                auto pos = headers.at(":path").find("?dns=");
+                if (pos == string::npos) {
+                  pos = headers.at(":path").find("&dns=");
                 }
-                else {
+                if (pos == string::npos) {
                   DEBUGLOG("User error, unable to find the DNS parameter");
                   ++dnsdist::metrics::g_stats.nonCompliantQueries;
                   ++clientState->nonCompliantQueries;
@@ -824,11 +785,99 @@ void doh3Thread(ClientState* clientState)
                   h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to find the DNS parameter");
                   break;
                 }
+                // need to base64url decode this
+                string sdns(headers.at(":path").substr(pos + 5));
+                boost::replace_all(sdns, "-", "+");
+                boost::replace_all(sdns, "_", "/");
+                // re-add padding that may have been missing
+                switch (sdns.size() % 4) {
+                case 2:
+                  sdns.append(2, '=');
+                  break;
+                case 3:
+                  sdns.append(1, '=');
+                  break;
+                }
+
+                PacketBuffer decoded;
+                /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */
+                const size_t maxAdditionalSizeForEDNS = 35U;
+                /* rough estimate so we hopefully don't need a new allocation later */
+                /* We reserve at few additional bytes to be able to add EDNS later */
+                const size_t estimate = ((sdns.size() * 3) / 4);
+                decoded.reserve(estimate + maxAdditionalSizeForEDNS);
+                if (B64Decode(sdns, decoded) < 0) {
+                  DEBUGLOG("Unable to base64 decode()");
+                  ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                  ++clientState->nonCompliantQueries;
+                  ++frontend->d_errorResponses;
+                  h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL");
+                  break;
+                }
+                if (decoded.size() < sizeof(dnsheader)) {
+                  ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                  ++clientState->nonCompliantQueries;
+                  ++frontend->d_errorResponses;
+                  h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query");
+                  break;
+                }
+                DEBUGLOG("Dispatching GET query");
+                doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID);
+                conn->get().d_streamBuffers.erase(streamID);
+              }
+              else if (headers.at(":method") == "POST") {
+                if (!quiche_h3_event_headers_has_body(ev)) {
+                  DEBUGLOG("Empty POST query");
+                  ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                  ++clientState->nonCompliantQueries;
+                  ++frontend->d_errorResponses;
+                  h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Empty POST query");
+                  break;
+                }
+              }
+              else {
+                DEBUGLOG("Unsupported HTTP method");
+                ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                ++clientState->nonCompliantQueries;
+                ++frontend->d_errorResponses;
+                h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported HTTP method");
+                break;
               }
               break;
             }
+            case QUICHE_H3_EVENT_DATA: {
+              if (!headers.count("content-type") || headers.at("content-type") != "application/dns-message") {
+                DEBUGLOG("Unsupported content-type");
+                ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                ++clientState->nonCompliantQueries;
+                ++frontend->d_errorResponses;
+                h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported content-type");
+                break;
+              }
+              PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
+              PacketBuffer decoded;
+
+              while (true) {
+                ssize_t len = quiche_h3_recv_body(conn->get().d_http3.get(),
+                                                  conn->get().d_conn.get(), streamID,
+                                                  buffer.data(), buffer.capacity());
 
-            case QUICHE_H3_EVENT_DATA:
+                if (len <= 0) {
+                  break;
+                }
+                decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
+              }
+              if (decoded.size() < sizeof(dnsheader)) {
+                ++dnsdist::metrics::g_stats.nonCompliantQueries;
+                ++clientState->nonCompliantQueries;
+                ++frontend->d_errorResponses;
+                h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query");
+                break;
+              }
+              DEBUGLOG("Dispatching POST query");
+              doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID);
+              conn->get().d_streamBuffers.erase(streamID);
+            }
             case QUICHE_H3_EVENT_FINISHED:
             case QUICHE_H3_EVENT_RESET:
             case QUICHE_H3_EVENT_PRIORITY_UPDATE:
index 211928aee55bd56e1195e5a7f1ca4b109a4e78d5..4b0b2868f935c3387e2039a6149ef5b5b866913d 100644 (file)
 
 #ifdef HAVE_DNS_OVER_QUIC
 
+#if 0
+#define DEBUGLOG_ENABLED
+#define DEBUGLOG(x) std::cerr << x << std::endl;
+#else
+#define DEBUGLOG(x)
+#endif
+
 namespace dnsdist::doq
 {
 
index c8064a86069e95915f6798485d8d21373c097b4e..7af19ecb79540d4864832dcd5bdf8af5902fe917 100644 (file)
 #include "sstuff.hh"
 #include "libssl.hh"
 
-#if 0
-#define DEBUGLOG_ENABLED
-#define DEBUGLOG(x) std::cerr << x << std::endl;
-#else
-#define DEBUGLOG(x)
-#endif
-
 namespace dnsdist::doq
 {
 
index 7037349652c91adc45f69db38095bd4aa84ccc9b..66a71b8504e96c3948fed1042b700f8fd691ea68 100644 (file)
 
 using namespace dnsdist::doq;
 
+#if 0
+#define DEBUGLOG_ENABLED
+#define DEBUGLOG(x) std::cerr << x << std::endl;
+#else
+#define DEBUGLOG(x)
+#endif
+
 class Connection
 {
 public:
@@ -576,7 +583,6 @@ void doqThread(ClientState* clientState)
 
     Socket sock(clientState->udpFD);
 
-    PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
     auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
 
     auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
index 7b1cca35304f61211a179b81efaebf8b47823fd4..6dff1318ebb81e68ae820ee6063d662d96b0b203 100644 (file)
@@ -1129,7 +1129,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return (receivedQuery, message)
 
     @classmethod
-    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
+    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
 
         if response:
             if toQueue:
@@ -1137,7 +1137,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             else:
                 cls._toResponderQueue.put(response, True, timeout)
 
-        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName)
+        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)
 
         receivedQuery = None
 
index c04b06b152a435fd60fa03337613c999f2244783..85d66a399a51ce290161643783d820c678e0b7d7 100644 (file)
@@ -24,53 +24,10 @@ from aioquic.h3.events import (
 )
 from aioquic.quic.configuration import QuicConfiguration
 from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
-#from aioquic.quic.logger import QuicFileLogger
 from aioquic.tls import CipherSuite, SessionTicket
 
 from doqclient import StreamResetError
-#
-#class DnsClientProtocol(QuicConnectionProtocol):
-#    def __init__(self, *args, **kwargs):
-#        super().__init__(*args, **kwargs)
-#        self._ack_waiter: Any = None
-#
-#    def pack(self, data):
-#        # serialize query
-#        data = bytes(data)
-#        data = struct.pack("!H", len(data)) + data
-#        return data
-#
-#    async def query(self, query: dns.message) -> None:
-#        data = self.pack(query.to_wire())
-#        # send query and wait for answer
-#        stream_id = self._quic.get_next_available_stream_id()
-#        self._quic.send_stream_data(stream_id, data, end_stream=True)
-#        waiter = self._loop.create_future()
-#        self._ack_waiter = waiter
-#        self.transmit()
-#
-#        return await asyncio.shield(waiter)
-#
-#    def quic_event_received(self, event: QuicEvent) -> None:
-#        if self._ack_waiter is not None:
-#            if isinstance(event, StreamDataReceived):
-#                length = struct.unpack("!H", bytes(event.data[:2]))[0]
-#                answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True)
-#
-#                waiter = self._ack_waiter
-#                self._ack_waiter = None
-#                waiter.set_result(answer)
-#            if isinstance(event, StreamReset):
-#                waiter = self._ack_waiter
-#                self._ack_waiter = None
-#                waiter.set_result(event)
-#
-#class BogusDnsClientProtocol(DnsClientProtocol):
-#    def pack(self, data):
-#        # serialize query
-#        data = bytes(data)
-#        data = struct.pack("!H", len(data) * 2) + data
-#        return data
+
 HttpConnection = Union[H0Connection, H3Connection]
 
 class URL:
@@ -195,20 +152,19 @@ class HttpClient(QuicConnectionProtocol):
 async def perform_http_request(
     client: HttpClient,
     url: str,
-    data: Optional[str],
+    data: Optional[bytes],
     include: bool,
     output_dir: Optional[str],
 ) -> None:
     # perform request
     start = time.time()
     if data is not None:
-        data_bytes = data.encode()
         http_events = await client.post(
             url,
-            data=data_bytes,
+            data=data,
             headers={
-                "content-length": str(len(data_bytes)),
-                "content-type": "application/x-www-form-urlencoded",
+                "content-length": str(len(data)),
+                "content-type": "application/dns-message",
             },
         )
         method = "POST"
@@ -232,10 +188,13 @@ async def async_h3_query(
     port: int,
     query: dns.message,
     timeout: float,
-    create_protocol=HttpClient
+    post: bool,
+    create_protocol=HttpClient,
 ) -> None:
 
-    url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('='))
+    url = baseurl
+    if not post:
+        url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('='))
     async with connect(
         "127.0.0.1",
         port,
@@ -250,7 +209,7 @@ async def async_h3_query(
                 answer = await perform_http_request(
                     client=client,
                     url=url,
-                    data=None,
+                    data=query.to_wire() if post else None,
                     include=False,
                     output_dir=None,
                 )
@@ -259,10 +218,12 @@ async def async_h3_query(
         except asyncio.TimeoutError as e:
             return e
 
-def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None):
+
+def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False):
     configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
     if verify:
         configuration.load_verify_locations(verify)
+
     result = asyncio.run(
         async_h3_query(
             configuration=configuration,
@@ -270,7 +231,8 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname
             port=port,
             query=query,
             timeout=timeout,
-            create_protocol=HttpClient
+            create_protocol=HttpClient,
+            post=post
         )
     )
 
index dcff35e096903cd3940757d996340e7849a4fc86..4374bc2a9f4365f6e6b77e03d89530695f7320aa 100644 (file)
@@ -32,3 +32,41 @@ class TestDOH3(QUICTests, DNSDistTest):
 
     def sendQUICQuery(self, query, response=None, useQueue=True, connection=None):
         return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection)
+
+class TestDOH3Specifics(DNSDistTest):
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _doqServerPort = pickAvailablePort()
+    _dohBaseURL = ("https://%s:%d/" % (_serverName, _doqServerPort))
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    addDOH3Local("127.0.0.1:%d", "%s", "%s", {keyLogFile='/tmp/keys'})
+    """
+    _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey']
+    _verboseMode = True
+
+    def testDOH3Post(self):
+        """
+        QUIC: Simple POST query
+        """
+        name = 'simple.post.doq.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.id = 0
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.id = 0
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, serverName=self._serverName, post=True)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.assertEqual(receivedResponse, response)