]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Gather Server Name Indication on QUIC (DoQ, DoH3) connections 15024/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 9 Jan 2025 14:36:49 +0000 (15:36 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 23 Jan 2025 09:38:05 +0000 (10:38 +0100)
pdns/dnsdistdist/doh3.cc
pdns/dnsdistdist/doh3.hh
pdns/dnsdistdist/doq-common.cc
pdns/dnsdistdist/doq-common.hh
pdns/dnsdistdist/doq.cc
pdns/dnsdistdist/doq.hh
pdns/dnsdistdist/m4/pdns_with_quiche.m4
regression-tests.dnsdist/doh3client.py
regression-tests.dnsdist/doqclient.py
regression-tests.dnsdist/test_SNI.py [new file with mode: 0644]

index 6ee0db58b994231e2171aea01172c3292b1d4a57..edd08934d1d1467103886d3dd6073e4317656917 100644 (file)
@@ -62,6 +62,14 @@ public:
   H3Connection& operator=(H3Connection&&) = default;
   ~H3Connection() = default;
 
+  std::shared_ptr<const std::string> getSNI()
+  {
+    if (!d_sni) {
+      d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
+    }
+    return d_sni;
+  }
+
   ComboAddress d_peer;
   ComboAddress d_localAddr;
   QuicheConnection d_conn;
@@ -71,6 +79,7 @@ public:
   std::unordered_map<uint64_t, dnsdist::doh3::h3_headers_t> d_headersBuffers;
   std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
   std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
+  std::shared_ptr<const std::string> d_sni{nullptr};
 };
 
 static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
@@ -566,6 +575,9 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
       ids.origFlags = *flags;
       return true;
     });
+    if (unit->sni) {
+      dnsQuestion.sni = *unit->sni;
+    }
     unit->ids.cs = &clientState;
 
     auto result = processQuery(dnsQuestion, downstream);
@@ -640,7 +652,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
   }
 }
 
-static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers)
+static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni, dnsdist::doh3::h3_headers_t&& headers)
 {
   try {
     auto unit = std::make_unique<DOH3Unit>(std::move(query));
@@ -650,6 +662,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con
     unit->ids.protocol = dnsdist::Protocol::DoH3;
     unit->serverConnID = serverConnID;
     unit->streamID = streamID;
+    unit->sni = sni;
     unit->headers = std::move(headers);
 
     processDOH3Query(std::move(unit));
@@ -751,7 +764,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
       return;
     }
     DEBUGLOG("Dispatching GET query");
-    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
+    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
     conn.d_streamBuffers.erase(streamID);
     conn.d_headersBuffers.erase(streamID);
     return;
@@ -816,7 +829,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
   }
 
   DEBUGLOG("Dispatching POST query");
-  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
+  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
   conn.d_headersBuffers.erase(streamID);
   conn.d_streamBuffers.erase(streamID);
 }
index 97a52a2c42022361e4388d3bc11f5d9ff995e4c3..9e021c6114dbad17acfec4890ddf05cb4712c7f4 100644 (file)
@@ -101,6 +101,7 @@ struct DOH3Unit
   PacketBuffer serverConnID;
   dnsdist::doh3::h3_headers_t headers;
   std::shared_ptr<DownstreamState> downstream{nullptr};
+  std::shared_ptr<const std::string> sni{nullptr};
   std::string d_contentTypeOut;
   DOH3ServerConfig* dsc{nullptr};
   uint64_t streamID{0};
index bb79ddc21849693f6a474ac40851c25bf081156b..ea6145476131947022ca085743cad79a84be079b 100644 (file)
@@ -325,6 +325,18 @@ bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, C
   return !buffer.empty();
 }
 
-};
+std::string getSNIFromQuicheConnection(const QuicheConnection& conn)
+{
+#if defined(HAVE_QUICHE_CONN_SERVER_NAME)
+  const uint8_t* sniPtr = nullptr;
+  size_t sniPtrSize = 0;
+  quiche_conn_server_name(conn.get(), &sniPtr, &sniPtrSize);
+  if (sniPtrSize > 0) {
+    return std::string(reinterpret_cast<const char*>(sniPtr), sniPtrSize);
+  }
+#endif /* HAVE_QUICHE_CONN_SERVER_NAME */
+  return {};
+}
+}
 
 #endif
index 9b04e4c83581460e5842aadaa32720e257cb4751..43c6bd9c55dbe9e197bc16e99946ce24d832bca3 100644 (file)
@@ -23,6 +23,7 @@
 
 #include <map>
 #include <memory>
+#include <string>
 
 #include "config.h"
 
@@ -97,7 +98,7 @@ void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, co
 void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer);
 void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP);
 bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr);
-
+std::string getSNIFromQuicheConnection(const QuicheConnection& conn);
 };
 
 #endif
index 661fe5c2b5f96885be1aa95a1e3876d3f22277f0..73536cffec456c748fd41a54d074b5da137fd52a 100644 (file)
@@ -61,6 +61,14 @@ public:
   Connection& operator=(Connection&&) = default;
   ~Connection() = default;
 
+  std::shared_ptr<const std::string> getSNI()
+  {
+    if (!d_sni) {
+      d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
+    }
+    return d_sni;
+  }
+
   ComboAddress d_peer;
   ComboAddress d_localAddr;
   QuicheConnection d_conn;
@@ -68,6 +76,7 @@ public:
 
   std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
   std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
+  std::shared_ptr<const std::string> d_sni{nullptr};
 };
 
 static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
@@ -472,6 +481,9 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
       ids.origFlags = *flags;
       return true;
     });
+    if (unit->sni) {
+      dnsQuestion.sni = *unit->sni;
+    }
     unit->ids.cs = &clientState;
 
     auto result = processQuery(dnsQuestion, downstream);
@@ -541,7 +553,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
   }
 }
 
-static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
+static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni)
 {
   try {
     auto unit = std::make_unique<DOQUnit>(std::move(query));
@@ -551,6 +563,7 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const
     unit->ids.protocol = dnsdist::Protocol::DoQ;
     unit->serverConnID = serverConnID;
     unit->streamID = streamID;
+    unit->sni = sni;
 
     processDOQQuery(std::move(unit));
   }
@@ -649,7 +662,7 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState
     return;
   }
   DEBUGLOG("Dispatching query");
-  doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
+  doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI());
   conn.d_streamBuffers.erase(streamID);
 }
 
index 258194177a6ea635299d507d6da79bee8d04c021..4b0d2dc550a6314b11a0b773850ee1cb65b56386 100644 (file)
@@ -84,6 +84,7 @@ struct DOQUnit
   PacketBuffer response;
   PacketBuffer serverConnID;
   std::shared_ptr<DownstreamState> downstream{nullptr};
+  std::shared_ptr<const std::string> sni{nullptr};
   DOQServerConfig* dsc{nullptr};
   uint64_t streamID{0};
   size_t proxyProtocolPayloadSize{0};
index 672fe0f79f225c0e4ec30648bf464ed05906e581..784c9cb869eec34d0855bf529dc01587e88a8cea 100644 (file)
@@ -21,6 +21,16 @@ AC_DEFUN([PDNS_WITH_QUICHE], [
           AC_DEFINE([HAVE_QUICHE], [1], [Define to 1 if you have quiche])
         ], [ : ])
       ])
+      AS_IF([test "x$HAVE_QUICHE" = "x1"], [
+        save_CFLAGS=$CFLAGS
+        save_LIBS=$LIBS
+        CFLAGS="$QUICHE_CFLAGS $CFLAGS"
+        LIBS="$QUICHE_LIBS $LIBS"
+        AC_CHECK_FUNCS([quiche_conn_server_name])
+        CFLAGS=$save_CFLAGS
+        LIBS=$save_LIBS
+
+      ])
     ])
   ])
   AM_CONDITIONAL([HAVE_QUICHE], [test "x$QUICHE_LIBS" != "x"])
index 953f5befa0abadffcc30f5dff9c25766761f9e21..e2e2bc107ad685dd4f63b4c52ba9d4bfb16db005 100644 (file)
@@ -228,7 +228,7 @@ async def async_h3_query(
 
 
 def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
-    configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
+    configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name=server_hostname)
     if verify:
         configuration.load_verify_locations(verify)
 
index 2f0272630f3b9399a1ae9e20305fa6bacd51450c..7fa416d237c533fad97a5f273e49d538a4a183ee 100644 (file)
@@ -88,7 +88,7 @@ class StreamResetError(Exception):
         super().__init__(message)
 
 def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
-    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
     if verify:
         configuration.load_verify_locations(verify)
     (result, serial) = asyncio.run(
@@ -108,7 +108,7 @@ def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server
     return (result, serial)
 
 def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
-    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
     if verify:
         configuration.load_verify_locations(verify)
     (result, _) = asyncio.run(
diff --git a/regression-tests.dnsdist/test_SNI.py b/regression-tests.dnsdist/test_SNI.py
new file mode 100644 (file)
index 0000000..ac2f221
--- /dev/null
@@ -0,0 +1,82 @@
+#!/usr/bin/env python
+import base64
+import dns
+import os
+import unittest
+import pycurl
+
+from dnsdisttests import DNSDistTest, pickAvailablePort
+
+class TestSNI(DNSDistTest):
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = pickAvailablePort()
+    _dohWithNGHTTP2ServerPort = pickAvailablePort()
+    _doqServerPort = pickAvailablePort()
+    _doh3ServerPort = pickAvailablePort()
+    _dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (_serverName, _dohWithNGHTTP2ServerPort))
+    _dohBaseURL = ("https://%s:%d/" % (_serverName, _doh3ServerPort))
+
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
+    addDOQLocal("127.0.0.1:%d", "%s", "%s")
+    addDOH3Local("127.0.0.1:%d", "%s", "%s")
+
+    function displaySNI(dq)
+      local sni = dq:getServerNameIndication()
+      if sni ~= '%s' then
+        return DNSAction.Spoof, '1.2.3.4'
+      end
+      return DNSAction.Allow
+    end
+    addAction(AllRule(), LuaAction(displaySNI))
+    """
+    _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName']
+
+    # enable these once Quiche > 0.22 is available, including https://github.com/cloudflare/quiche/pull/1895
+    @unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quicheare disabled")
+    def testServerNameIndicationWithQuiche(self):
+        name = 'simple.sni.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        for method in ["sendDOQQueryWrapper", "sendDOH3QueryWrapper"]:
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response, timeout=1)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertTrue(receivedResponse)
+            if method == 'sendDOQQueryWrapper':
+                # dnspython sets the ID to 0
+                receivedResponse.id = response.id
+            self.assertEqual(response, receivedResponse)
+
+    def testServerNameIndication(self):
+        name = 'simple.sni.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]:
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response, timeout=1)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(response, receivedResponse)