From 396bc0d448a04a982526d7d84bef7bf8ed9b73c8 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Fri, 4 Apr 2025 15:18:31 +0200 Subject: [PATCH] dnsdist: Add support for switching certificates based on SNI w/ OpenSSL We already supported this with GnuTLS, but OpenSSL does not make it easy: we need to keep a different `SSL_CTX` object for each certificate/key and change the `SSL_CTX` associated with an incoming connection to the correct one based on the Server Name Indication from the servername callback (actually OpenSSL devs advise to use the ClientHello callback instead when using a recent enough version of OpenSSL, but the SNI hostname is not available is not available at this point so we would have to parse it ourselves, which is a terrible idea, and the drawbacks are not clear. `nginx` has been getting away with it, so hopefully we will as well). One additional issue is that we still need to load certificates for the same name but different key types (RSA vs ECDSA, for example) in the same `SSL_CTX` context, which makes the code a bit convoluted. --- .../dnsdistdist/dnsdist-configuration-yaml.cc | 3 +- pdns/dnsdistdist/dnsdist-lua.cc | 12 +- pdns/dnsdistdist/doh.cc | 4 +- pdns/dnsdistdist/doq-common.cc | 4 +- pdns/libssl.cc | 230 ++++++++++++++++-- pdns/libssl.hh | 22 +- pdns/tcpiohandler.cc | 77 ++++-- regression-tests.dnsdist/.gitignore | 5 + regression-tests.dnsdist/Makefile | 14 ++ regression-tests.dnsdist/configServer2.conf | 20 ++ regression-tests.dnsdist/dnsdisttests.py | 24 +- regression-tests.dnsdist/test_SNI.py | 129 +++++++++- 12 files changed, 471 insertions(+), 73 deletions(-) create mode 100644 regression-tests.dnsdist/configServer2.conf diff --git a/pdns/dnsdistdist/dnsdist-configuration-yaml.cc b/pdns/dnsdistdist/dnsdist-configuration-yaml.cc index 7aea2c681e..edd9c3f788 100644 --- a/pdns/dnsdistdist/dnsdist-configuration-yaml.cc +++ b/pdns/dnsdistdist/dnsdist-configuration-yaml.cc @@ -228,8 +228,7 @@ static bool validateTLSConfiguration(const dnsdist::rust::settings::BindConfigur // we are asked to try to load the certificates so we can return a potential error // and properly ignore the frontend before actually launching it try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(tlsConfig); } catch (const std::runtime_error& e) { errlog("Ignoring %s frontend: '%s'", bind.protocol, e.what()); diff --git a/pdns/dnsdistdist/dnsdist-lua.cc b/pdns/dnsdistdist/dnsdist-lua.cc index 3504cbd6a4..ef2ac732ae 100644 --- a/pdns/dnsdistdist/dnsdist-lua.cc +++ b/pdns/dnsdistdist/dnsdist-lua.cc @@ -2248,8 +2248,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) // we are asked to try to load the certificates so we can return a potential error // and properly ignore the frontend before actually launching it try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig); } catch (const std::runtime_error& e) { errlog("Ignoring DoH frontend: '%s'", e.what()); @@ -2346,8 +2345,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) // we are asked to try to load the certificates so we can return a potential error // and properly ignore the frontend before actually launching it try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig); } catch (const std::runtime_error& e) { errlog("Ignoring DoH3 frontend: '%s'", e.what()); @@ -2423,8 +2421,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) // we are asked to try to load the certificates so we can return a potential error // and properly ignore the frontend before actually launching it try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig); } catch (const std::runtime_error& e) { errlog("Ignoring DoQ frontend: '%s'", e.what()); @@ -2776,8 +2773,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) // we are asked to try to load the certificates so we can return a potential error // and properly ignore the frontend before actually launching it try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(frontend->d_tlsConfig); } catch (const std::runtime_error& e) { errlog("Ignoring TLS frontend: '%s'", e.what()); diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index fc03c86430..3eb1bccd54 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -1482,7 +1482,7 @@ static void setupTLSContext(DOHAcceptContext& acceptCtx, tlsConfig.d_ciphers = DOH_DEFAULT_CIPHERS.data(); } - auto [ctx, warnings] = libssl_init_server_context(tlsConfig, acceptCtx.d_ocspResponses); + auto [ctx, warnings] = libssl_init_server_context_no_sni(tlsConfig, acceptCtx.d_ocspResponses); for (const auto& warning : warnings) { warnlog("%s", warning); } @@ -1504,7 +1504,7 @@ static void setupTLSContext(DOHAcceptContext& acceptCtx, } #endif /* DISABLE_OCSP_STAPLING */ - libssl_set_error_counters_callback(ctx, &counters); + libssl_set_error_counters_callback(*ctx.get(), &counters); if (!tlsConfig.d_keyLogFile.empty()) { acceptCtx.d_keyLogFile = libssl_set_key_log_file(ctx.get(), tlsConfig.d_keyLogFile); diff --git a/pdns/dnsdistdist/doq-common.cc b/pdns/dnsdistdist/doq-common.cc index ce2993dcd6..f4b58f7d48 100644 --- a/pdns/dnsdistdist/doq-common.cc +++ b/pdns/dnsdistdist/doq-common.cc @@ -215,12 +215,12 @@ void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHT for (const auto& pair : params.d_tlsConfig.d_certKeyPairs) { auto res = quiche_config_load_cert_chain_from_pem_file(config.get(), pair.d_cert.c_str()); if (res != 0) { - throw std::runtime_error("Error loading the server certificate: " + std::to_string(res)); + throw std::runtime_error("Error loading the server certificate from '" + pair.d_cert + "': " + std::to_string(res)); } if (pair.d_key) { res = quiche_config_load_priv_key_from_pem_file(config.get(), pair.d_key->c_str()); if (res != 0) { - throw std::runtime_error("Error loading the server key: " + std::to_string(res)); + throw std::runtime_error("Error loading the server key from '" + *(pair.d_key) + "': " + std::to_string(res)); } } } diff --git a/pdns/libssl.cc b/pdns/libssl.cc index 8b7b3e5de2..cd386b7f9d 100644 --- a/pdns/libssl.cc +++ b/pdns/libssl.cc @@ -28,6 +28,7 @@ #endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */ #include #include +#include #include #if OPENSSL_VERSION_MAJOR >= 3 @@ -353,10 +354,10 @@ static void libssl_info_callback(const SSL *ssl, int where, int /* ret */) } } -void libssl_set_error_counters_callback(std::unique_ptr& ctx, TLSErrorCounters* counters) +void libssl_set_error_counters_callback(SSL_CTX& ctx, TLSErrorCounters* counters) { - SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters); - SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback); + SSL_CTX_set_ex_data(&ctx, s_countersIndex, counters); + SSL_CTX_set_info_callback(&ctx, libssl_info_callback); } #ifndef DISABLE_OCSP_STAPLING @@ -518,12 +519,12 @@ bool libssl_generate_ocsp_response(const std::string& certFile, const std::strin #endif /* HAVE_OCSP_BASIC_SIGN */ #endif /* DISABLE_OCSP_STAPLING */ -static int libssl_get_last_key_type(std::unique_ptr& ctx) +static int libssl_get_last_key_type(SSL_CTX& ctx) { #ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY - auto pkey = SSL_CTX_get0_privatekey(ctx.get()); + auto pkey = SSL_CTX_get0_privatekey(&ctx); #else - auto temp = std::unique_ptr(SSL_new(ctx.get()), SSL_free); + auto temp = std::unique_ptr(SSL_new(&ctx), SSL_free); if (!temp) { return -1; } @@ -537,6 +538,61 @@ static int libssl_get_last_key_type(std::unique_ptr get_names_from_certificate(const X509* certificate) +{ + std::unordered_set result; + auto names = std::unique_ptr(static_cast(X509_get_ext_d2i(certificate, NID_subject_alt_name, nullptr, nullptr))); + if (names) { + for (int idx = 0; idx < sk_GENERAL_NAME_num(names.get()); idx++) { + const auto* name = sk_GENERAL_NAME_value(names.get(), idx); + if (name->type != GEN_DNS) { + /* ignore GEN_IPADD / name->d.iPAddress (raw IP address bytes), it cannot be used in SNI anyway */ + continue; + } + unsigned char* str = nullptr; + if (ASN1_STRING_to_UTF8(&str, name->d.dNSName) < 0) { + continue; + } + result.emplace(reinterpret_cast(str)); + OPENSSL_free(str); + } + } + + auto* name = X509_get_subject_name(certificate); + if (name != nullptr) { + ssize_t idx = -1; + while ((idx = X509_NAME_get_index_by_NID(name, NID_commonName, idx)) != -1) { + const auto* entry = X509_NAME_get_entry(name, idx); + const auto* value = X509_NAME_ENTRY_get_data(entry); + unsigned char* str = nullptr; + if (ASN1_STRING_to_UTF8(&str, value) < 0) { + continue; + } + result.emplace(reinterpret_cast(str)); + OPENSSL_free(str); + } + } + + return result; +} + +static std::unordered_set get_names_from_last_certificate(const SSL_CTX& ctx) +{ + const auto* cert = SSL_CTX_get0_certificate(&ctx); + if (cert == nullptr) { + return {}; + } + + return get_names_from_certificate(cert); +} + LibsslTLSVersion libssl_tls_version_from_string(const std::string& str) { if (str == "tls1.0") { @@ -570,7 +626,7 @@ const std::string& libssl_tls_version_to_string(LibsslTLSVersion version) return it->second; } -bool libssl_set_min_tls_version(std::unique_ptr& ctx, LibsslTLSVersion version) +static bool libssl_set_min_tls_version(SSL_CTX& ctx, LibsslTLSVersion version) { #if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version) /* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated @@ -597,7 +653,7 @@ bool libssl_set_min_tls_version(std::unique_ptr, std::vector> libssl_init_server_context(const TLSConfig& config, - [[maybe_unused]] std::map& ocspResponses) +static std::unique_ptr getNewServerContext(const TLSConfig& config, [[maybe_unused]] std::vector& warnings) { - std::vector warnings; auto ctx = std::unique_ptr(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free); if (!ctx) { @@ -952,7 +1006,7 @@ std::pair, std::vector, std::vector& names, const std::function& existingContextCallback) +{ + for (const auto& name : names) { + auto [existingEntry, inserted] = serverContext.d_sniMap.emplace(name, newContext); + if (!inserted) { + auto& existingContext = existingEntry->second; + existingContextCallback(existingContext); + } + else if (serverContext.d_sniMap.size() == 1) { + serverContext.d_defaultContext = newContext; + } + } +} + +std::pair, std::vector> libssl_init_server_context_no_sni(const TLSConfig& config, + [[maybe_unused]] std::map& ocspResponses) +{ + std::vector warnings; + auto ctx = getNewServerContext(config, warnings); + std::vector keyTypes; /* load certificate and private key */ for (const auto& pair : config.d_certKeyPairs) { @@ -1055,12 +1132,13 @@ std::pair, std::vector, std::vector> libssl_init_server_context(const TLSConfig& config) +{ + std::vector warnings; + pdns::libssl::ServerContext serverContext; + + std::vector keyTypes; + /* load certificate and private key */ + for (const auto& pair : config.d_certKeyPairs) { + auto uniqueCtx = getNewServerContext(config, warnings); + auto ctx = std::shared_ptr(uniqueCtx.release(), SSL_CTX_free); + if (!pair.d_key) { +#if defined(HAVE_SSL_CTX_USE_CERT_AND_KEY) + // If no separate key is given, treat it as a pkcs12 file + auto filePtr = pdns::UniqueFilePtr(fopen(pair.d_cert.c_str(), "r")); + if (!filePtr) { + throw std::runtime_error("Unable to open file " + pair.d_cert); + } + auto p12 = std::unique_ptr(d2i_PKCS12_fp(filePtr.get(), nullptr), PKCS12_free); + if (!p12) { + throw std::runtime_error("Unable to open PKCS12 file " + pair.d_cert); + } + EVP_PKEY *keyptr = nullptr; + X509 *certptr = nullptr; + STACK_OF(X509) *captr = nullptr; + if (!PKCS12_parse(p12.get(), (pair.d_password ? pair.d_password->c_str() : nullptr), &keyptr, &certptr, &captr)) { +#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 + bool failed = true; + /* we might be opening a PKCS12 file that uses RC2 CBC or 3DES CBC which, since OpenSSL 3.0.0, requires loading the legacy provider */ + auto libCtx = OSSL_LIB_CTX_get0_global_default(); + /* check whether the legacy provider is already loaded */ + if (!OSSL_PROVIDER_available(libCtx, "legacy")) { + /* it's not */ + auto provider = OSSL_PROVIDER_load(libCtx, "legacy"); + if (provider != nullptr) { + if (PKCS12_parse(p12.get(), (pair.d_password ? pair.d_password->c_str() : nullptr), &keyptr, &certptr, &captr)) { + failed = false; + } + /* we do not want to keep that provider around after that */ + OSSL_PROVIDER_unload(provider); + } + } + if (failed) { +#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */ + ERR_print_errors_fp(stderr); + throw std::runtime_error("An error occured while parsing PKCS12 file " + pair.d_cert); +#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 + } +#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */ + } + auto key = std::unique_ptr(keyptr, EVP_PKEY_free); + auto cert = std::unique_ptr(certptr, X509_free); + auto ca = std::unique_ptr(captr, [](STACK_OF(X509)* st){ sk_X509_free(st); }); + + auto addCertificateAndKey = [&pair, &key, &cert, &ca](std::shared_ptr& tlsContext) { + if (SSL_CTX_use_cert_and_key(tlsContext.get(), cert.get(), key.get(), ca.get(), 1) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("An error occurred while trying to load the TLS certificate and key from PKCS12 file " + pair.d_cert); + } + }; + + addCertificateAndKey(ctx); + auto names = get_names_from_last_certificate(*ctx); + mergeNewCertificateAndKey(serverContext, ctx, names, addCertificateAndKey); +#else + throw std::runtime_error("PKCS12 files are not supported by your openssl version"); +#endif /* HAVE_SSL_CTX_USE_CERT_AND_KEY */ + } else { + auto addCertificateAndKey = [&pair](std::shared_ptr& tlsContext) { + if (SSL_CTX_use_certificate_chain_file(tlsContext.get(), pair.d_cert.c_str()) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.d_cert); + } + if (SSL_CTX_use_PrivateKey_file(tlsContext.get(), pair.d_key->c_str(), SSL_FILETYPE_PEM) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.d_key.value()); + } + }; + + addCertificateAndKey(ctx); + auto names = get_names_from_last_certificate(*ctx); + mergeNewCertificateAndKey(serverContext, ctx, names, addCertificateAndKey); + } + + if (SSL_CTX_check_private_key(ctx.get()) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("The key from '" + pair.d_key.value() + "' does not match the certificate from '" + pair.d_cert + "'"); + } + /* store the type of the new key, we might need it later to select the right OCSP stapling response */ + auto keyType = libssl_get_last_key_type(*ctx.get()); + if (keyType < 0) { + throw std::runtime_error("The key from '" + pair.d_key.value() + "' has an unknown type"); + } + keyTypes.push_back(keyType); + } + +#ifndef DISABLE_OCSP_STAPLING + if (!config.d_ocspFiles.empty()) { + try { + serverContext.d_ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, std::move(keyTypes), warnings); + } + catch(const std::exception& e) { + throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what())); + } + } +#endif /* DISABLE_OCSP_STAPLING */ + + for (auto& entry : serverContext.d_sniMap) { + auto& ctx = entry.second; + if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) { + throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers); + } + +#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES + if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) { + throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13); + } +#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */ + } + + return {std::move(serverContext), std::move(warnings)}; +} + #ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK static void libssl_key_log_file_callback(const SSL* ssl, const char* line) { diff --git a/pdns/libssl.hh b/pdns/libssl.hh index 3d3e5a8eba..fd545019e4 100644 --- a/pdns/libssl.hh +++ b/pdns/libssl.hh @@ -147,16 +147,30 @@ bool libssl_generate_ocsp_response(const std::string& certFile, const std::strin #endif #endif /* DISABLE_OCSP_STAPLING */ -void libssl_set_error_counters_callback(std::unique_ptr& ctx, TLSErrorCounters* counters); +void libssl_set_error_counters_callback(SSL_CTX& ctx, TLSErrorCounters* counters); LibsslTLSVersion libssl_tls_version_from_string(const std::string& str); const std::string& libssl_tls_version_to_string(LibsslTLSVersion version); -bool libssl_set_min_tls_version(std::unique_ptr& ctx, LibsslTLSVersion version); + + +namespace pdns::libssl { +class ServerContext +{ +public: + using SharedContext = std::shared_ptr; + using SNIToContextMap = std::map>; + + SharedContext d_defaultContext; + SNIToContextMap d_sniMap; + std::map d_ocspResponses; +}; +} /* return the created context, and a list of warning messages for issues not severe enough to trigger raising an exception, like failing to load an OCSP response file */ -std::pair, std::vector> libssl_init_server_context(const TLSConfig& config, - std::map& ocspResponses); +std::pair, std::vector> libssl_init_server_context_no_sni(const TLSConfig& config, + std::map& ocspResponses); +std::pair> libssl_init_server_context(const TLSConfig& config); pdns::UniqueFilePtr libssl_set_key_log_file(SSL_CTX* ctx, const std::string& logFile); diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 9379576b60..60cb189258 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -57,6 +57,7 @@ bool shouldDoVerboseLogging() #include "libssl.hh" +static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* arg); class OpenSSLFrontendContext { @@ -65,11 +66,16 @@ public: { registerOpenSSLUser(); - auto [ctx, warnings] = libssl_init_server_context(tlsConfig, d_ocspResponses); + auto [ctx, warnings] = libssl_init_server_context(tlsConfig); for (const auto& warning : warnings) { warnlog("%s", warning); } - d_tlsCtx = std::move(ctx); + d_ocspResponses = std::move(ctx.d_ocspResponses); + d_tlsCtx = std::move(ctx.d_defaultContext); + d_sniMap = std::move(ctx.d_sniMap); + for (auto& entry : d_sniMap) { + SSL_CTX_set_tlsext_servername_callback(entry.second.get(), &sni_server_name_callback); + } if (!d_tlsCtx) { ERR_print_errors_fp(stderr); @@ -86,10 +92,38 @@ public: OpenSSLTLSTicketKeysRing d_ticketKeys; std::map d_ocspResponses; - std::unique_ptr d_tlsCtx{nullptr, SSL_CTX_free}; + pdns::libssl::ServerContext::SNIToContextMap d_sniMap; + std::shared_ptr d_tlsCtx{nullptr}; pdns::UniqueFilePtr d_keyLogFile{nullptr}; }; + +static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* /* arg */) +{ + const auto* serverName = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (serverName == nullptr) { + return SSL_TLSEXT_ERR_NOACK; + } + auto* frontendCtx = reinterpret_cast(libssl_get_ticket_key_callback_data(ssl)); + if (frontendCtx == nullptr) { + return SSL_TLSEXT_ERR_OK; + } + + auto serverNameView = std::string_view(serverName); + + auto it = frontendCtx->d_sniMap.find(serverNameView); + if (it == frontendCtx->d_sniMap.end()) { + /* keep the default certificate */ + return SSL_TLSEXT_ERR_OK; + } + + /* if it fails there is nothing we can do, + let's hope OpenSSL will fallback to the existing, + default certificate*/ + SSL_set_SSL_CTX(ssl, it->second.get()); + return SSL_TLSEXT_ERR_OK; +} + class OpenSSLSession : public TLSSession { public: @@ -649,33 +683,36 @@ public: d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay; - if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) { - /* use our own ticket keys handler so we can rotate them */ + for (auto& entry : d_feContext->d_sniMap) { + auto* ctx = entry.second.get(); + if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) { + /* use our own ticket keys handler so we can rotate them */ #if OPENSSL_VERSION_MAJOR >= 3 - SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); + SSL_CTX_set_tlsext_ticket_key_evp_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb); #else - SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); + SSL_CTX_set_tlsext_ticket_key_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb); #endif - libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get()); - } + libssl_set_ticket_key_callback_data(ctx, d_feContext.get()); + } #ifndef DISABLE_OCSP_STAPLING - if (!d_feContext->d_ocspResponses.empty()) { - SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb); - SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses); - } + if (!d_feContext->d_ocspResponses.empty()) { + SSL_CTX_set_tlsext_status_cb(ctx, &OpenSSLTLSIOCtx::ocspStaplingCb); + SSL_CTX_set_tlsext_status_arg(ctx, &d_feContext->d_ocspResponses); + } #endif /* DISABLE_OCSP_STAPLING */ - if (frontend.d_tlsConfig.d_readAhead) { - SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1); - } + if (frontend.d_tlsConfig.d_readAhead) { + SSL_CTX_set_read_ahead(ctx, 1); + } - libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &frontend.d_tlsCounters); + libssl_set_error_counters_callback(*ctx, &frontend.d_tlsCounters); - libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this); + libssl_set_alpn_select_callback(ctx, alpnServerSelectCallback, this); - if (!frontend.d_tlsConfig.d_keyLogFile.empty()) { - d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx.get(), frontend.d_tlsConfig.d_keyLogFile); + if (!frontend.d_tlsConfig.d_keyLogFile.empty()) { + d_feContext->d_keyLogFile = libssl_set_key_log_file(ctx, frontend.d_tlsConfig.d_keyLogFile); + } } try { diff --git a/regression-tests.dnsdist/.gitignore b/regression-tests.dnsdist/.gitignore index 0b347c4993..f5c450fbbd 100644 --- a/regression-tests.dnsdist/.gitignore +++ b/regression-tests.dnsdist/.gitignore @@ -14,7 +14,12 @@ /server.csr /server.key /server.pem +/server2.chain +/server2.csr +/server2.key +/server2.pem /server.p12 +/server-ec.* /server-doq.* /server-doh3.* /server-ocsp.chain diff --git a/regression-tests.dnsdist/Makefile b/regression-tests.dnsdist/Makefile index 84286d7a4a..e851c8c149 100644 --- a/regression-tests.dnsdist/Makefile +++ b/regression-tests.dnsdist/Makefile @@ -13,3 +13,17 @@ certs: cat server.pem ca.pem > server.chain # Generate a password-protected PKCS12 file openssl pkcs12 -export -passout pass:passw0rd -clcerts -in server.pem -CAfile ca.pem -inkey server.key -out server.p12 + # Generate a second server certificate request + openssl req -new -newkey rsa:2048 -nodes -keyout server2.key -out server2.csr -config configServer2.conf + # Sign the server cert + openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server2.csr -out server2.pem -extfile configServer2.conf -extensions v3_req + # Generate a chain + cat server2.pem ca.pem > server2.chain + # Generate a ECDSA key with P-256 + openssl ecparam -name secp256r1 -genkey -noout -out server-ec.key + # Generate a new server certificate request with the ECDSA key + openssl req -new -key server-ec.key -nodes -out server-ec.csr -config configServer.conf + # Sign the server cert + openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server-ec.csr -out server-ec.pem -extfile configServer.conf -extensions v3_req + # Generate a chain + cat server-ec.pem ca.pem > server-ec.chain diff --git a/regression-tests.dnsdist/configServer2.conf b/regression-tests.dnsdist/configServer2.conf new file mode 100644 index 0000000000..1208d580ac --- /dev/null +++ b/regression-tests.dnsdist/configServer2.conf @@ -0,0 +1,20 @@ +[req] +default_bits = 2048 +encrypt_key = no +prompt = no +distinguished_name = server_distinguished_name +req_extensions = v3_req + +[server_distinguished_name] +CN = tls2.tests.dnsdist.org +OU = PowerDNS.com BV +countryName = NL + +[v3_req] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +subjectAltName = @alt_names + +[alt_names] +DNS.1 = tls2.tests.dnsdist.org +IP.2 = 192.0.2.1 diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 53c97b04b4..c14ab88310 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -1129,23 +1129,23 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): cls._response_headers = response_headers.getvalue() return (receivedQuery, message) - def sendDOHQueryWrapper(self, query, response, useQueue=True, timeout=2): - return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout) + def sendDOHQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None): + return self.sendDOHQuery(self._dohServerPort, self._serverName if not serverName else serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout) - def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True, timeout=2): - return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout) + def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None): + return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName if not serverName else serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout) - def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True, timeout=2): - return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout) + def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None): + return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName if not serverName else serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout) - def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2): - return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout) + def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None): + return self.sendDOTQuery(self._tlsServerPort, self._serverName if not serverName else serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout) - def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2): - return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, timeout=timeout) + def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None): + return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout) - def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2): - return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, timeout=timeout) + def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None): + return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout) @classmethod def getDOQConnection(cls, port, caFile=None, source=None, source_port=0): diff --git a/regression-tests.dnsdist/test_SNI.py b/regression-tests.dnsdist/test_SNI.py index 1e93cc6ce4..31ac684d47 100644 --- a/regression-tests.dnsdist/test_SNI.py +++ b/regression-tests.dnsdist/test_SNI.py @@ -4,13 +4,20 @@ import dns import os import unittest import pycurl +import ssl from dnsdisttests import DNSDistTest, pickAvailablePort class TestSNI(DNSDistTest): _serverKey = 'server.key' _serverCert = 'server.chain' + _serverKeyEC = 'server-ec.key' + _serverCertEC = 'server-ec.chain' + _serverKey2 = 'server2.key' + _serverCert2 = 'server2.chain' _serverName = 'tls.tests.dnsdist.org' + _serverName2 = 'tls2.tests.dnsdist.org' + _serverName3 = 'unknown.tests.dnsdist.org' _caCert = 'ca.pem' _tlsServerPort = pickAvailablePort() _dohWithNGHTTP2ServerPort = pickAvailablePort() @@ -22,21 +29,37 @@ class TestSNI(DNSDistTest): _config_template = """ newServer{address="127.0.0.1:%d"} - addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" }) - addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"}) - addDOQLocal("127.0.0.1:%d", "%s", "%s") - addDOH3Local("127.0.0.1:%d", "%s", "%s") + local certs = {"%s", "%s", "%s"} + local keys = {"%s", "%s", "%s"} + local single_cert = "%s" + local single_key = "%s" + addTLSLocal("127.0.0.1:%d", certs, keys, { provider="openssl" }) + addDOHLocal("127.0.0.1:%d", certs, keys, {"/"}, {library="nghttp2"}) + addDOQLocal("127.0.0.1:%d", single_cert, single_key) + addDOH3Local("127.0.0.1:%d", single_cert, single_key) - function displaySNI(dq) + function checkSNI(dq) local sni = dq:getServerNameIndication() - if sni ~= '%s' then + if tostring(dq.qname) == 'simple.sni.tests.powerdns.com.' and sni ~= '%s' then return DNSAction.Spoof, '1.2.3.4' end + if tostring(dq.qname) == 'name2.sni.tests.powerdns.com.' and sni ~= '%s' then + return DNSAction.Spoof, '2.3.4.5' + end + if tostring(dq.qname) == 'unknown.sni.tests.powerdns.com.' and sni ~= '%s' then + return DNSAction.Spoof, '3.4.5.6' + end + if tostring(dq.qname) == 'ecdsa.sni.tests.powerdns.com.' and sni ~= '%s' then + return DNSAction.Spoof, '4.5.6.7' + end + if tostring(dq.qname) == 'rsa.sni.tests.powerdns.com.' and sni ~= '%s' then + return DNSAction.Spoof, '4.5.6.7' + end return DNSAction.Allow end - addAction(AllRule(), LuaAction(displaySNI)) + addAction(AllRule(), LuaAction(checkSNI)) """ - _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName'] + _config_params = ['_testServerPort', '_serverCert', '_serverCertEC', '_serverCert2', '_serverKey', '_serverKeyEC', '_serverKey2', '_serverCert', '_serverKey', '_tlsServerPort', '_dohWithNGHTTP2ServerPort', '_doqServerPort', '_doh3ServerPort', '_serverName', '_serverName2', '_serverName3', '_serverName', '_serverName'] @unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quiche are disabled") def testServerNameIndicationWithQuiche(self): @@ -79,3 +102,93 @@ class TestSNI(DNSDistTest): self.assertEqual(query, receivedQuery) self.assertTrue(receivedResponse) self.assertEqual(response, receivedResponse) + + # check second certificate + name = 'name2.sni.tests.powerdns.com.' + self._dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (self._serverName2, self._dohWithNGHTTP2ServerPort)) + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]: + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response, timeout=1, serverName=self._serverName2) + self.assertTrue(receivedQuery) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertTrue(receivedResponse) + self.assertEqual(response, receivedResponse) + + # check SNI for an unkown name, we should get the first certificate + name = 'unknown.sni.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + sslctx = ssl.create_default_context(cafile=self._caCert) + sslctx.check_hostname = False + if hasattr(sslctx, 'set_alpn_protocols'): + sslctx.set_alpn_protocols(self._serverName3) + + conn = self.openTLSConnection(self._tlsServerPort, self._serverName3, self._caCert, timeout=1, sslctx=sslctx) + self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1) + (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1) + receivedQuery.id = query.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + + cert = conn.getpeercert() + subject = cert['subject'] + altNames = cert['subjectAltName'] + self.assertEqual(dict(subject[0])['commonName'], 'tls.tests.dnsdist.org') + self.assertEqual(dict(subject[1])['organizationalUnitName'], 'PowerDNS.com BV') + names = [] + for entry in altNames: + names.append(entry[1]) + self.assertEqual(names, ['tls.tests.dnsdist.org', 'powerdns.com', '127.0.0.1']) + + # check that we provide the correct RSA/ECDSA certificate when requested + for algo in ['rsa', 'ecdsa']: + name = algo + '.sni.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + sslctx = ssl.create_default_context(cafile=self._caCert) + if hasattr(sslctx, 'set_alpn_protocols'): + sslctx.set_alpn_protocols(self._serverName) + # disable TLS 1.3 because configuring the signature algorithm is not supported by Python yet + sslctx.maximum_version = ssl.TLSVersion.TLSv1_2 + # explicitly request authentication via RSA or ECDSA + sslctx.set_ciphers('a' + algo.upper()) + + conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert, timeout=1, sslctx=sslctx) + self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1) + (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1) + receivedQuery.id = query.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + + cert = conn.getpeercert() + subject = cert['subject'] + altNames = cert['subjectAltName'] + self.assertEqual(dict(subject[0])['commonName'], 'tls.tests.dnsdist.org') + self.assertEqual(dict(subject[1])['organizationalUnitName'], 'PowerDNS.com BV') + names = [] + for entry in altNames: + names.append(entry[1]) + self.assertEqual(names, ['tls.tests.dnsdist.org', 'powerdns.com', '127.0.0.1']) -- 2.47.2