H3Connection& operator=(H3Connection&&) = default;
~H3Connection() = default;
+ std::shared_ptr<const std::string> getSNI()
+ {
+ if (!d_sni) {
+ d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
+ }
+ return d_sni;
+ }
+
ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
std::unordered_map<uint64_t, dnsdist::doh3::h3_headers_t> d_headersBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
+ std::shared_ptr<const std::string> d_sni{nullptr};
};
static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
ids.origFlags = *flags;
return true;
});
+ if (unit->sni) {
+ dnsQuestion.sni = *unit->sni;
+ }
unit->ids.cs = &clientState;
auto result = processQuery(dnsQuestion, downstream);
}
}
-static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers)
+static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni, dnsdist::doh3::h3_headers_t&& headers)
{
try {
auto unit = std::make_unique<DOH3Unit>(std::move(query));
unit->ids.protocol = dnsdist::Protocol::DoH3;
unit->serverConnID = serverConnID;
unit->streamID = streamID;
+ unit->sni = sni;
unit->headers = std::move(headers);
processDOH3Query(std::move(unit));
return;
}
DEBUGLOG("Dispatching GET query");
- doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
+ doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
conn.d_streamBuffers.erase(streamID);
conn.d_headersBuffers.erase(streamID);
return;
}
DEBUGLOG("Dispatching POST query");
- doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
+ doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
conn.d_headersBuffers.erase(streamID);
conn.d_streamBuffers.erase(streamID);
}
PacketBuffer serverConnID;
dnsdist::doh3::h3_headers_t headers;
std::shared_ptr<DownstreamState> downstream{nullptr};
+ std::shared_ptr<const std::string> sni{nullptr};
std::string d_contentTypeOut;
DOH3ServerConfig* dsc{nullptr};
uint64_t streamID{0};
return !buffer.empty();
}
-};
+std::string getSNIFromQuicheConnection(const QuicheConnection& conn)
+{
+#if defined(HAVE_QUICHE_CONN_SERVER_NAME)
+ const uint8_t* sniPtr = nullptr;
+ size_t sniPtrSize = 0;
+ quiche_conn_server_name(conn.get(), &sniPtr, &sniPtrSize);
+ if (sniPtrSize > 0) {
+ return std::string(reinterpret_cast<const char*>(sniPtr), sniPtrSize);
+ }
+#endif /* HAVE_QUICHE_CONN_SERVER_NAME */
+ return {};
+}
+}
#endif
#include <map>
#include <memory>
+#include <string>
#include "config.h"
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer);
void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP);
bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr);
-
+std::string getSNIFromQuicheConnection(const QuicheConnection& conn);
};
#endif
Connection& operator=(Connection&&) = default;
~Connection() = default;
+ std::shared_ptr<const std::string> getSNI()
+ {
+ if (!d_sni) {
+ d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
+ }
+ return d_sni;
+ }
+
ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
+ std::shared_ptr<const std::string> d_sni{nullptr};
};
static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
ids.origFlags = *flags;
return true;
});
+ if (unit->sni) {
+ dnsQuestion.sni = *unit->sni;
+ }
unit->ids.cs = &clientState;
auto result = processQuery(dnsQuestion, downstream);
}
}
-static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
+static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni)
{
try {
auto unit = std::make_unique<DOQUnit>(std::move(query));
unit->ids.protocol = dnsdist::Protocol::DoQ;
unit->serverConnID = serverConnID;
unit->streamID = streamID;
+ unit->sni = sni;
processDOQQuery(std::move(unit));
}
return;
}
DEBUGLOG("Dispatching query");
- doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
+ doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI());
conn.d_streamBuffers.erase(streamID);
}
PacketBuffer response;
PacketBuffer serverConnID;
std::shared_ptr<DownstreamState> downstream{nullptr};
+ std::shared_ptr<const std::string> sni{nullptr};
DOQServerConfig* dsc{nullptr};
uint64_t streamID{0};
size_t proxyProtocolPayloadSize{0};
AC_DEFINE([HAVE_QUICHE], [1], [Define to 1 if you have quiche])
], [ : ])
])
+ AS_IF([test "x$HAVE_QUICHE" = "x1"], [
+ save_CFLAGS=$CFLAGS
+ save_LIBS=$LIBS
+ CFLAGS="$QUICHE_CFLAGS $CFLAGS"
+ LIBS="$QUICHE_LIBS $LIBS"
+ AC_CHECK_FUNCS([quiche_conn_server_name])
+ CFLAGS=$save_CFLAGS
+ LIBS=$save_LIBS
+
+ ])
])
])
AM_CONDITIONAL([HAVE_QUICHE], [test "x$QUICHE_LIBS" != "x"])
def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
- configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
+ configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
super().__init__(message)
def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
- configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+ configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
(result, serial) = asyncio.run(
return (result, serial)
def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
- configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+ configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
(result, _) = asyncio.run(
--- /dev/null
+#!/usr/bin/env python
+import base64
+import dns
+import os
+import unittest
+import pycurl
+
+from dnsdisttests import DNSDistTest, pickAvailablePort
+
+class TestSNI(DNSDistTest):
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _tlsServerPort = pickAvailablePort()
+ _dohWithNGHTTP2ServerPort = pickAvailablePort()
+ _doqServerPort = pickAvailablePort()
+ _doh3ServerPort = pickAvailablePort()
+ _dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (_serverName, _dohWithNGHTTP2ServerPort))
+ _dohBaseURL = ("https://%s:%d/" % (_serverName, _doh3ServerPort))
+
+ _config_template = """
+ newServer{address="127.0.0.1:%d"}
+
+ addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
+ addDOQLocal("127.0.0.1:%d", "%s", "%s")
+ addDOH3Local("127.0.0.1:%d", "%s", "%s")
+
+ function displaySNI(dq)
+ local sni = dq:getServerNameIndication()
+ if sni ~= '%s' then
+ return DNSAction.Spoof, '1.2.3.4'
+ end
+ return DNSAction.Allow
+ end
+ addAction(AllRule(), LuaAction(displaySNI))
+ """
+ _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName']
+
+ # enable these once Quiche > 0.22 is available, including https://github.com/cloudflare/quiche/pull/1895
+ @unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quicheare disabled")
+ def testServerNameIndicationWithQuiche(self):
+ name = 'simple.sni.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+ for method in ["sendDOQQueryWrapper", "sendDOH3QueryWrapper"]:
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response, timeout=1)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertTrue(receivedResponse)
+ if method == 'sendDOQQueryWrapper':
+ # dnspython sets the ID to 0
+ receivedResponse.id = response.id
+ self.assertEqual(response, receivedResponse)
+
+ def testServerNameIndication(self):
+ name = 'simple.sni.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+ for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]:
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response, timeout=1)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertTrue(receivedResponse)
+ self.assertEqual(response, receivedResponse)