// we are asked to try to load the certificates so we can return a potential error
// and properly ignore the frontend before actually launching it
try {
- std::map<int, std::string> ocspResponses = {};
- auto ctx = libssl_init_server_context(tlsConfig, ocspResponses);
+ auto ctx = libssl_init_server_context(tlsConfig);
}
catch (const std::runtime_error& e) {
errlog("Ignoring %s frontend: '%s'", bind.protocol, e.what());
// we are asked to try to load the certificates so we can return a potential error
// and properly ignore the frontend before actually launching it
try {
- std::map<int, std::string> ocspResponses = {};
- auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig, ocspResponses);
+ auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig);
}
catch (const std::runtime_error& e) {
errlog("Ignoring DoH frontend: '%s'", e.what());
// we are asked to try to load the certificates so we can return a potential error
// and properly ignore the frontend before actually launching it
try {
- std::map<int, std::string> ocspResponses = {};
- auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig, ocspResponses);
+ auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig);
}
catch (const std::runtime_error& e) {
errlog("Ignoring DoH3 frontend: '%s'", e.what());
// we are asked to try to load the certificates so we can return a potential error
// and properly ignore the frontend before actually launching it
try {
- std::map<int, std::string> ocspResponses = {};
- auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig, ocspResponses);
+ auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig);
}
catch (const std::runtime_error& e) {
errlog("Ignoring DoQ frontend: '%s'", e.what());
// we are asked to try to load the certificates so we can return a potential error
// and properly ignore the frontend before actually launching it
try {
- std::map<int, std::string> ocspResponses = {};
- auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
+ auto ctx = libssl_init_server_context(frontend->d_tlsConfig);
}
catch (const std::runtime_error& e) {
errlog("Ignoring TLS frontend: '%s'", e.what());
tlsConfig.d_ciphers = DOH_DEFAULT_CIPHERS.data();
}
- auto [ctx, warnings] = libssl_init_server_context(tlsConfig, acceptCtx.d_ocspResponses);
+ auto [ctx, warnings] = libssl_init_server_context_no_sni(tlsConfig, acceptCtx.d_ocspResponses);
for (const auto& warning : warnings) {
warnlog("%s", warning);
}
}
#endif /* DISABLE_OCSP_STAPLING */
- libssl_set_error_counters_callback(ctx, &counters);
+ libssl_set_error_counters_callback(*ctx.get(), &counters);
if (!tlsConfig.d_keyLogFile.empty()) {
acceptCtx.d_keyLogFile = libssl_set_key_log_file(ctx.get(), tlsConfig.d_keyLogFile);
for (const auto& pair : params.d_tlsConfig.d_certKeyPairs) {
auto res = quiche_config_load_cert_chain_from_pem_file(config.get(), pair.d_cert.c_str());
if (res != 0) {
- throw std::runtime_error("Error loading the server certificate: " + std::to_string(res));
+ throw std::runtime_error("Error loading the server certificate from '" + pair.d_cert + "': " + std::to_string(res));
}
if (pair.d_key) {
res = quiche_config_load_priv_key_from_pem_file(config.get(), pair.d_key->c_str());
if (res != 0) {
- throw std::runtime_error("Error loading the server key: " + std::to_string(res));
+ throw std::runtime_error("Error loading the server key from '" + *(pair.d_key) + "': " + std::to_string(res));
}
}
}
#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */
#include <openssl/rand.h>
#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
#include <fcntl.h>
#if OPENSSL_VERSION_MAJOR >= 3
}
}
-void libssl_set_error_counters_callback(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, TLSErrorCounters* counters)
+void libssl_set_error_counters_callback(SSL_CTX& ctx, TLSErrorCounters* counters)
{
- SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters);
- SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback);
+ SSL_CTX_set_ex_data(&ctx, s_countersIndex, counters);
+ SSL_CTX_set_info_callback(&ctx, libssl_info_callback);
}
#ifndef DISABLE_OCSP_STAPLING
#endif /* HAVE_OCSP_BASIC_SIGN */
#endif /* DISABLE_OCSP_STAPLING */
-static int libssl_get_last_key_type(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx)
+static int libssl_get_last_key_type(SSL_CTX& ctx)
{
#ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY
- auto pkey = SSL_CTX_get0_privatekey(ctx.get());
+ auto pkey = SSL_CTX_get0_privatekey(&ctx);
#else
- auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(ctx.get()), SSL_free);
+ auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(&ctx), SSL_free);
if (!temp) {
return -1;
}
return EVP_PKEY_base_id(pkey);
}
+struct StackOfNamesDeleter
+{
+ void operator()(STACK_OF(GENERAL_NAME)* ptr) const noexcept {
+ sk_GENERAL_NAME_pop_free(ptr, GENERAL_NAME_free);
+ }
+};
+
+static std::unordered_set<std::string> get_names_from_certificate(const X509* certificate)
+{
+ std::unordered_set<std::string> result;
+ auto names = std::unique_ptr<STACK_OF(GENERAL_NAME), StackOfNamesDeleter>(static_cast<STACK_OF(GENERAL_NAME)*>(X509_get_ext_d2i(certificate, NID_subject_alt_name, nullptr, nullptr)));
+ if (names) {
+ for (int idx = 0; idx < sk_GENERAL_NAME_num(names.get()); idx++) {
+ const auto* name = sk_GENERAL_NAME_value(names.get(), idx);
+ if (name->type != GEN_DNS) {
+ /* ignore GEN_IPADD / name->d.iPAddress (raw IP address bytes), it cannot be used in SNI anyway */
+ continue;
+ }
+ unsigned char* str = nullptr;
+ if (ASN1_STRING_to_UTF8(&str, name->d.dNSName) < 0) {
+ continue;
+ }
+ result.emplace(reinterpret_cast<const char*>(str));
+ OPENSSL_free(str);
+ }
+ }
+
+ auto* name = X509_get_subject_name(certificate);
+ if (name != nullptr) {
+ ssize_t idx = -1;
+ while ((idx = X509_NAME_get_index_by_NID(name, NID_commonName, idx)) != -1) {
+ const auto* entry = X509_NAME_get_entry(name, idx);
+ const auto* value = X509_NAME_ENTRY_get_data(entry);
+ unsigned char* str = nullptr;
+ if (ASN1_STRING_to_UTF8(&str, value) < 0) {
+ continue;
+ }
+ result.emplace(reinterpret_cast<const char*>(str));
+ OPENSSL_free(str);
+ }
+ }
+
+ return result;
+}
+
+static std::unordered_set<std::string> get_names_from_last_certificate(const SSL_CTX& ctx)
+{
+ const auto* cert = SSL_CTX_get0_certificate(&ctx);
+ if (cert == nullptr) {
+ return {};
+ }
+
+ return get_names_from_certificate(cert);
+}
+
LibsslTLSVersion libssl_tls_version_from_string(const std::string& str)
{
if (str == "tls1.0") {
return it->second;
}
-bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, LibsslTLSVersion version)
+static bool libssl_set_min_tls_version(SSL_CTX& ctx, LibsslTLSVersion version)
{
#if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version)
/* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated
return false;
}
- if (SSL_CTX_set_min_proto_version(ctx.get(), vers) != 1) {
+ if (SSL_CTX_set_min_proto_version(&ctx, vers) != 1) {
return false;
}
return true;
return false;
}
- long options = SSL_CTX_get_options(ctx.get());
- SSL_CTX_set_options(ctx.get(), options | vers);
+ long options = SSL_CTX_get_options(&ctx);
+ SSL_CTX_set_options(&ctx, options | vers);
return true;
#endif
}
return true;
}
-std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config,
- [[maybe_unused]] std::map<int, std::string>& ocspResponses)
+static std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> getNewServerContext(const TLSConfig& config, [[maybe_unused]] std::vector<std::string>& warnings)
{
- std::vector<std::string> warnings;
auto ctx = std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
if (!ctx) {
#endif
SSL_CTX_set_options(ctx.get(), sslOptions);
- if (!libssl_set_min_tls_version(ctx, config.d_minTLSVersion)) {
+ if (!libssl_set_min_tls_version(*ctx.get(), config.d_minTLSVersion)) {
throw std::runtime_error("Failed to set the minimum version to '" + libssl_tls_version_to_string(config.d_minTLSVersion));
}
session is resumed, causing SSL_get_servername to return nullptr */
SSL_CTX_set_tlsext_servername_callback(ctx.get(), &libssl_server_name_callback);
+ return ctx;
+}
+
+static void mergeNewCertificateAndKey(pdns::libssl::ServerContext& serverContext, pdns::libssl::ServerContext::SharedContext newContext, std::unordered_set<std::string>& names, const std::function<void(pdns::libssl::ServerContext::SharedContext&)>& existingContextCallback)
+{
+ for (const auto& name : names) {
+ auto [existingEntry, inserted] = serverContext.d_sniMap.emplace(name, newContext);
+ if (!inserted) {
+ auto& existingContext = existingEntry->second;
+ existingContextCallback(existingContext);
+ }
+ else if (serverContext.d_sniMap.size() == 1) {
+ serverContext.d_defaultContext = newContext;
+ }
+ }
+}
+
+std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context_no_sni(const TLSConfig& config,
+ [[maybe_unused]] std::map<int, std::string>& ocspResponses)
+{
+ std::vector<std::string> warnings;
+ auto ctx = getNewServerContext(config, warnings);
+
std::vector<int> keyTypes;
/* load certificate and private key */
for (const auto& pair : config.d_certKeyPairs) {
throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.d_key.value());
}
}
+
if (SSL_CTX_check_private_key(ctx.get()) != 1) {
ERR_print_errors_fp(stderr);
throw std::runtime_error("The key from '" + pair.d_key.value() + "' does not match the certificate from '" + pair.d_cert + "'");
}
/* store the type of the new key, we might need it later to select the right OCSP stapling response */
- auto keyType = libssl_get_last_key_type(ctx);
+ auto keyType = libssl_get_last_key_type(*ctx.get());
if (keyType < 0) {
throw std::runtime_error("The key from '" + pair.d_key.value() + "' has an unknown type");
}
return {std::move(ctx), std::move(warnings)};
}
+std::pair<pdns::libssl::ServerContext, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config)
+{
+ std::vector<std::string> warnings;
+ pdns::libssl::ServerContext serverContext;
+
+ std::vector<int> keyTypes;
+ /* load certificate and private key */
+ for (const auto& pair : config.d_certKeyPairs) {
+ auto uniqueCtx = getNewServerContext(config, warnings);
+ auto ctx = std::shared_ptr<SSL_CTX>(uniqueCtx.release(), SSL_CTX_free);
+ if (!pair.d_key) {
+#if defined(HAVE_SSL_CTX_USE_CERT_AND_KEY)
+ // If no separate key is given, treat it as a pkcs12 file
+ auto filePtr = pdns::UniqueFilePtr(fopen(pair.d_cert.c_str(), "r"));
+ if (!filePtr) {
+ throw std::runtime_error("Unable to open file " + pair.d_cert);
+ }
+ auto p12 = std::unique_ptr<PKCS12, void(*)(PKCS12*)>(d2i_PKCS12_fp(filePtr.get(), nullptr), PKCS12_free);
+ if (!p12) {
+ throw std::runtime_error("Unable to open PKCS12 file " + pair.d_cert);
+ }
+ EVP_PKEY *keyptr = nullptr;
+ X509 *certptr = nullptr;
+ STACK_OF(X509) *captr = nullptr;
+ if (!PKCS12_parse(p12.get(), (pair.d_password ? pair.d_password->c_str() : nullptr), &keyptr, &certptr, &captr)) {
+#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3
+ bool failed = true;
+ /* we might be opening a PKCS12 file that uses RC2 CBC or 3DES CBC which, since OpenSSL 3.0.0, requires loading the legacy provider */
+ auto libCtx = OSSL_LIB_CTX_get0_global_default();
+ /* check whether the legacy provider is already loaded */
+ if (!OSSL_PROVIDER_available(libCtx, "legacy")) {
+ /* it's not */
+ auto provider = OSSL_PROVIDER_load(libCtx, "legacy");
+ if (provider != nullptr) {
+ if (PKCS12_parse(p12.get(), (pair.d_password ? pair.d_password->c_str() : nullptr), &keyptr, &certptr, &captr)) {
+ failed = false;
+ }
+ /* we do not want to keep that provider around after that */
+ OSSL_PROVIDER_unload(provider);
+ }
+ }
+ if (failed) {
+#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("An error occured while parsing PKCS12 file " + pair.d_cert);
+#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3
+ }
+#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */
+ }
+ auto key = std::unique_ptr<EVP_PKEY, void(*)(EVP_PKEY*)>(keyptr, EVP_PKEY_free);
+ auto cert = std::unique_ptr<X509, void(*)(X509*)>(certptr, X509_free);
+ auto ca = std::unique_ptr<STACK_OF(X509), void(*)(STACK_OF(X509)*)>(captr, [](STACK_OF(X509)* st){ sk_X509_free(st); });
+
+ auto addCertificateAndKey = [&pair, &key, &cert, &ca](std::shared_ptr<SSL_CTX>& tlsContext) {
+ if (SSL_CTX_use_cert_and_key(tlsContext.get(), cert.get(), key.get(), ca.get(), 1) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("An error occurred while trying to load the TLS certificate and key from PKCS12 file " + pair.d_cert);
+ }
+ };
+
+ addCertificateAndKey(ctx);
+ auto names = get_names_from_last_certificate(*ctx);
+ mergeNewCertificateAndKey(serverContext, ctx, names, addCertificateAndKey);
+#else
+ throw std::runtime_error("PKCS12 files are not supported by your openssl version");
+#endif /* HAVE_SSL_CTX_USE_CERT_AND_KEY */
+ } else {
+ auto addCertificateAndKey = [&pair](std::shared_ptr<SSL_CTX>& tlsContext) {
+ if (SSL_CTX_use_certificate_chain_file(tlsContext.get(), pair.d_cert.c_str()) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.d_cert);
+ }
+ if (SSL_CTX_use_PrivateKey_file(tlsContext.get(), pair.d_key->c_str(), SSL_FILETYPE_PEM) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.d_key.value());
+ }
+ };
+
+ addCertificateAndKey(ctx);
+ auto names = get_names_from_last_certificate(*ctx);
+ mergeNewCertificateAndKey(serverContext, ctx, names, addCertificateAndKey);
+ }
+
+ if (SSL_CTX_check_private_key(ctx.get()) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("The key from '" + pair.d_key.value() + "' does not match the certificate from '" + pair.d_cert + "'");
+ }
+ /* store the type of the new key, we might need it later to select the right OCSP stapling response */
+ auto keyType = libssl_get_last_key_type(*ctx.get());
+ if (keyType < 0) {
+ throw std::runtime_error("The key from '" + pair.d_key.value() + "' has an unknown type");
+ }
+ keyTypes.push_back(keyType);
+ }
+
+#ifndef DISABLE_OCSP_STAPLING
+ if (!config.d_ocspFiles.empty()) {
+ try {
+ serverContext.d_ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, std::move(keyTypes), warnings);
+ }
+ catch(const std::exception& e) {
+ throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what()));
+ }
+ }
+#endif /* DISABLE_OCSP_STAPLING */
+
+ for (auto& entry : serverContext.d_sniMap) {
+ auto& ctx = entry.second;
+ if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) {
+ throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers);
+ }
+
+#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
+ if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) {
+ throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13);
+ }
+#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
+ }
+
+ return {std::move(serverContext), std::move(warnings)};
+}
+
#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
static void libssl_key_log_file_callback(const SSL* ssl, const char* line)
{
#endif
#endif /* DISABLE_OCSP_STAPLING */
-void libssl_set_error_counters_callback(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, TLSErrorCounters* counters);
+void libssl_set_error_counters_callback(SSL_CTX& ctx, TLSErrorCounters* counters);
LibsslTLSVersion libssl_tls_version_from_string(const std::string& str);
const std::string& libssl_tls_version_to_string(LibsslTLSVersion version);
-bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, LibsslTLSVersion version);
+
+
+namespace pdns::libssl {
+class ServerContext
+{
+public:
+ using SharedContext = std::shared_ptr<SSL_CTX>;
+ using SNIToContextMap = std::map<std::string, SharedContext, std::less<>>;
+
+ SharedContext d_defaultContext;
+ SNIToContextMap d_sniMap;
+ std::map<int, std::string> d_ocspResponses;
+};
+}
/* return the created context, and a list of warning messages for issues not severe enough
to trigger raising an exception, like failing to load an OCSP response file */
-std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config,
- std::map<int, std::string>& ocspResponses);
+std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context_no_sni(const TLSConfig& config,
+ std::map<int, std::string>& ocspResponses);
+std::pair<pdns::libssl::ServerContext, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config);
pdns::UniqueFilePtr libssl_set_key_log_file(SSL_CTX* ctx, const std::string& logFile);
#include "libssl.hh"
+static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* arg);
class OpenSSLFrontendContext
{
{
registerOpenSSLUser();
- auto [ctx, warnings] = libssl_init_server_context(tlsConfig, d_ocspResponses);
+ auto [ctx, warnings] = libssl_init_server_context(tlsConfig);
for (const auto& warning : warnings) {
warnlog("%s", warning);
}
- d_tlsCtx = std::move(ctx);
+ d_ocspResponses = std::move(ctx.d_ocspResponses);
+ d_tlsCtx = std::move(ctx.d_defaultContext);
+ d_sniMap = std::move(ctx.d_sniMap);
+ for (auto& entry : d_sniMap) {
+ SSL_CTX_set_tlsext_servername_callback(entry.second.get(), &sni_server_name_callback);
+ }
if (!d_tlsCtx) {
ERR_print_errors_fp(stderr);
OpenSSLTLSTicketKeysRing d_ticketKeys;
std::map<int, std::string> d_ocspResponses;
- std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
+ pdns::libssl::ServerContext::SNIToContextMap d_sniMap;
+ std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr};
pdns::UniqueFilePtr d_keyLogFile{nullptr};
};
+
+static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* /* arg */)
+{
+ const auto* serverName = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+ if (serverName == nullptr) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+ auto* frontendCtx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(ssl));
+ if (frontendCtx == nullptr) {
+ return SSL_TLSEXT_ERR_OK;
+ }
+
+ auto serverNameView = std::string_view(serverName);
+
+ auto it = frontendCtx->d_sniMap.find(serverNameView);
+ if (it == frontendCtx->d_sniMap.end()) {
+ /* keep the default certificate */
+ return SSL_TLSEXT_ERR_OK;
+ }
+
+ /* if it fails there is nothing we can do,
+ let's hope OpenSSL will fallback to the existing,
+ default certificate*/
+ SSL_set_SSL_CTX(ssl, it->second.get());
+ return SSL_TLSEXT_ERR_OK;
+}
+
class OpenSSLSession : public TLSSession
{
public:
d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;
- if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
- /* use our own ticket keys handler so we can rotate them */
+ for (auto& entry : d_feContext->d_sniMap) {
+ auto* ctx = entry.second.get();
+ if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
+ /* use our own ticket keys handler so we can rotate them */
#if OPENSSL_VERSION_MAJOR >= 3
- SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+ SSL_CTX_set_tlsext_ticket_key_evp_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
#else
- SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+ SSL_CTX_set_tlsext_ticket_key_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
#endif
- libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
- }
+ libssl_set_ticket_key_callback_data(ctx, d_feContext.get());
+ }
#ifndef DISABLE_OCSP_STAPLING
- if (!d_feContext->d_ocspResponses.empty()) {
- SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
- SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
- }
+ if (!d_feContext->d_ocspResponses.empty()) {
+ SSL_CTX_set_tlsext_status_cb(ctx, &OpenSSLTLSIOCtx::ocspStaplingCb);
+ SSL_CTX_set_tlsext_status_arg(ctx, &d_feContext->d_ocspResponses);
+ }
#endif /* DISABLE_OCSP_STAPLING */
- if (frontend.d_tlsConfig.d_readAhead) {
- SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
- }
+ if (frontend.d_tlsConfig.d_readAhead) {
+ SSL_CTX_set_read_ahead(ctx, 1);
+ }
- libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &frontend.d_tlsCounters);
+ libssl_set_error_counters_callback(*ctx, &frontend.d_tlsCounters);
- libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
+ libssl_set_alpn_select_callback(ctx, alpnServerSelectCallback, this);
- if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
- d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx.get(), frontend.d_tlsConfig.d_keyLogFile);
+ if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
+ d_feContext->d_keyLogFile = libssl_set_key_log_file(ctx, frontend.d_tlsConfig.d_keyLogFile);
+ }
}
try {
/server.csr
/server.key
/server.pem
+/server2.chain
+/server2.csr
+/server2.key
+/server2.pem
/server.p12
+/server-ec.*
/server-doq.*
/server-doh3.*
/server-ocsp.chain
cat server.pem ca.pem > server.chain
# Generate a password-protected PKCS12 file
openssl pkcs12 -export -passout pass:passw0rd -clcerts -in server.pem -CAfile ca.pem -inkey server.key -out server.p12
+ # Generate a second server certificate request
+ openssl req -new -newkey rsa:2048 -nodes -keyout server2.key -out server2.csr -config configServer2.conf
+ # Sign the server cert
+ openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server2.csr -out server2.pem -extfile configServer2.conf -extensions v3_req
+ # Generate a chain
+ cat server2.pem ca.pem > server2.chain
+ # Generate a ECDSA key with P-256
+ openssl ecparam -name secp256r1 -genkey -noout -out server-ec.key
+ # Generate a new server certificate request with the ECDSA key
+ openssl req -new -key server-ec.key -nodes -out server-ec.csr -config configServer.conf
+ # Sign the server cert
+ openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server-ec.csr -out server-ec.pem -extfile configServer.conf -extensions v3_req
+ # Generate a chain
+ cat server-ec.pem ca.pem > server-ec.chain
--- /dev/null
+[req]
+default_bits = 2048
+encrypt_key = no
+prompt = no
+distinguished_name = server_distinguished_name
+req_extensions = v3_req
+
+[server_distinguished_name]
+CN = tls2.tests.dnsdist.org
+OU = PowerDNS.com BV
+countryName = NL
+
+[v3_req]
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+subjectAltName = @alt_names
+
+[alt_names]
+DNS.1 = tls2.tests.dnsdist.org
+IP.2 = 192.0.2.1
cls._response_headers = response_headers.getvalue()
return (receivedQuery, message)
- def sendDOHQueryWrapper(self, query, response, useQueue=True, timeout=2):
- return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
+ def sendDOHQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+ return self.sendDOHQuery(self._dohServerPort, self._serverName if not serverName else serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
- def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True, timeout=2):
- return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
+ def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+ return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName if not serverName else serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
- def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True, timeout=2):
- return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
+ def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+ return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName if not serverName else serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
- def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2):
- return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout)
+ def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+ return self.sendDOTQuery(self._tlsServerPort, self._serverName if not serverName else serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout)
- def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2):
- return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, timeout=timeout)
+ def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+ return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout)
- def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2):
- return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, timeout=timeout)
+ def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+ return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout)
@classmethod
def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
import os
import unittest
import pycurl
+import ssl
from dnsdisttests import DNSDistTest, pickAvailablePort
class TestSNI(DNSDistTest):
_serverKey = 'server.key'
_serverCert = 'server.chain'
+ _serverKeyEC = 'server-ec.key'
+ _serverCertEC = 'server-ec.chain'
+ _serverKey2 = 'server2.key'
+ _serverCert2 = 'server2.chain'
_serverName = 'tls.tests.dnsdist.org'
+ _serverName2 = 'tls2.tests.dnsdist.org'
+ _serverName3 = 'unknown.tests.dnsdist.org'
_caCert = 'ca.pem'
_tlsServerPort = pickAvailablePort()
_dohWithNGHTTP2ServerPort = pickAvailablePort()
_config_template = """
newServer{address="127.0.0.1:%d"}
- addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
- addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
- addDOQLocal("127.0.0.1:%d", "%s", "%s")
- addDOH3Local("127.0.0.1:%d", "%s", "%s")
+ local certs = {"%s", "%s", "%s"}
+ local keys = {"%s", "%s", "%s"}
+ local single_cert = "%s"
+ local single_key = "%s"
+ addTLSLocal("127.0.0.1:%d", certs, keys, { provider="openssl" })
+ addDOHLocal("127.0.0.1:%d", certs, keys, {"/"}, {library="nghttp2"})
+ addDOQLocal("127.0.0.1:%d", single_cert, single_key)
+ addDOH3Local("127.0.0.1:%d", single_cert, single_key)
- function displaySNI(dq)
+ function checkSNI(dq)
local sni = dq:getServerNameIndication()
- if sni ~= '%s' then
+ if tostring(dq.qname) == 'simple.sni.tests.powerdns.com.' and sni ~= '%s' then
return DNSAction.Spoof, '1.2.3.4'
end
+ if tostring(dq.qname) == 'name2.sni.tests.powerdns.com.' and sni ~= '%s' then
+ return DNSAction.Spoof, '2.3.4.5'
+ end
+ if tostring(dq.qname) == 'unknown.sni.tests.powerdns.com.' and sni ~= '%s' then
+ return DNSAction.Spoof, '3.4.5.6'
+ end
+ if tostring(dq.qname) == 'ecdsa.sni.tests.powerdns.com.' and sni ~= '%s' then
+ return DNSAction.Spoof, '4.5.6.7'
+ end
+ if tostring(dq.qname) == 'rsa.sni.tests.powerdns.com.' and sni ~= '%s' then
+ return DNSAction.Spoof, '4.5.6.7'
+ end
return DNSAction.Allow
end
- addAction(AllRule(), LuaAction(displaySNI))
+ addAction(AllRule(), LuaAction(checkSNI))
"""
- _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName']
+ _config_params = ['_testServerPort', '_serverCert', '_serverCertEC', '_serverCert2', '_serverKey', '_serverKeyEC', '_serverKey2', '_serverCert', '_serverKey', '_tlsServerPort', '_dohWithNGHTTP2ServerPort', '_doqServerPort', '_doh3ServerPort', '_serverName', '_serverName2', '_serverName3', '_serverName', '_serverName']
@unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quiche are disabled")
def testServerNameIndicationWithQuiche(self):
self.assertEqual(query, receivedQuery)
self.assertTrue(receivedResponse)
self.assertEqual(response, receivedResponse)
+
+ # check second certificate
+ name = 'name2.sni.tests.powerdns.com.'
+ self._dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (self._serverName2, self._dohWithNGHTTP2ServerPort))
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+ for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]:
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response, timeout=1, serverName=self._serverName2)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertTrue(receivedResponse)
+ self.assertEqual(response, receivedResponse)
+
+ # check SNI for an unkown name, we should get the first certificate
+ name = 'unknown.sni.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ sslctx = ssl.create_default_context(cafile=self._caCert)
+ sslctx.check_hostname = False
+ if hasattr(sslctx, 'set_alpn_protocols'):
+ sslctx.set_alpn_protocols(self._serverName3)
+
+ conn = self.openTLSConnection(self._tlsServerPort, self._serverName3, self._caCert, timeout=1, sslctx=sslctx)
+ self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1)
+ (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1)
+ receivedQuery.id = query.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+
+ cert = conn.getpeercert()
+ subject = cert['subject']
+ altNames = cert['subjectAltName']
+ self.assertEqual(dict(subject[0])['commonName'], 'tls.tests.dnsdist.org')
+ self.assertEqual(dict(subject[1])['organizationalUnitName'], 'PowerDNS.com BV')
+ names = []
+ for entry in altNames:
+ names.append(entry[1])
+ self.assertEqual(names, ['tls.tests.dnsdist.org', 'powerdns.com', '127.0.0.1'])
+
+ # check that we provide the correct RSA/ECDSA certificate when requested
+ for algo in ['rsa', 'ecdsa']:
+ name = algo + '.sni.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ sslctx = ssl.create_default_context(cafile=self._caCert)
+ if hasattr(sslctx, 'set_alpn_protocols'):
+ sslctx.set_alpn_protocols(self._serverName)
+ # disable TLS 1.3 because configuring the signature algorithm is not supported by Python yet
+ sslctx.maximum_version = ssl.TLSVersion.TLSv1_2
+ # explicitly request authentication via RSA or ECDSA
+ sslctx.set_ciphers('a' + algo.upper())
+
+ conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert, timeout=1, sslctx=sslctx)
+ self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1)
+ (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1)
+ receivedQuery.id = query.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+
+ cert = conn.getpeercert()
+ subject = cert['subject']
+ altNames = cert['subjectAltName']
+ self.assertEqual(dict(subject[0])['commonName'], 'tls.tests.dnsdist.org')
+ self.assertEqual(dict(subject[1])['organizationalUnitName'], 'PowerDNS.com BV')
+ names = []
+ for entry in altNames:
+ names.append(entry[1])
+ self.assertEqual(names, ['tls.tests.dnsdist.org', 'powerdns.com', '127.0.0.1'])