]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add support for switching certificates based on SNI w/ OpenSSL
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 4 Apr 2025 13:18:31 +0000 (15:18 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 14 Apr 2025 12:10:27 +0000 (14:10 +0200)
We already supported this with GnuTLS, but OpenSSL does not make it
easy: we need to keep a different `SSL_CTX` object for each certificate/key
and change the `SSL_CTX` associated with an incoming connection to
the correct one based on the Server Name Indication from the servername
callback (actually OpenSSL devs advise to use the ClientHello callback
instead when using a recent enough version of OpenSSL, but the
SNI hostname is not available is not available at this point so we
would have to parse it ourselves, which is a terrible idea, and the
drawbacks are not clear. `nginx` has been getting away with it, so
hopefully we will as well).
One additional issue is that we still need to load certificates
for the same name but different key types (RSA vs ECDSA, for example)
in the same `SSL_CTX` context, which makes the code a bit convoluted.

12 files changed:
pdns/dnsdistdist/dnsdist-configuration-yaml.cc
pdns/dnsdistdist/dnsdist-lua.cc
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/doq-common.cc
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc
regression-tests.dnsdist/.gitignore
regression-tests.dnsdist/Makefile
regression-tests.dnsdist/configServer2.conf [new file with mode: 0644]
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_SNI.py

index 7aea2c681e90caa8f1aa9adbf1e133fc165a3406..edd9c3f788f371bd04c768ad09ddcad1d3a1e9d6 100644 (file)
@@ -228,8 +228,7 @@ static bool validateTLSConfiguration(const dnsdist::rust::settings::BindConfigur
   // we are asked to try to load the certificates so we can return a potential error
   // and properly ignore the frontend before actually launching it
   try {
-    std::map<int, std::string> ocspResponses = {};
-    auto ctx = libssl_init_server_context(tlsConfig, ocspResponses);
+    auto ctx = libssl_init_server_context(tlsConfig);
   }
   catch (const std::runtime_error& e) {
     errlog("Ignoring %s frontend: '%s'", bind.protocol, e.what());
index 3504cbd6a408277776f5cf0f6a5b2ed2ebeb21cc..ef2ac732ae264ac76ce571d15f6db429b8a4a3cf 100644 (file)
@@ -2248,8 +2248,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         // we are asked to try to load the certificates so we can return a potential error
         // and properly ignore the frontend before actually launching it
         try {
-          std::map<int, std::string> ocspResponses = {};
-          auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig, ocspResponses);
+          auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig);
         }
         catch (const std::runtime_error& e) {
           errlog("Ignoring DoH frontend: '%s'", e.what());
@@ -2346,8 +2345,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         // we are asked to try to load the certificates so we can return a potential error
         // and properly ignore the frontend before actually launching it
         try {
-          std::map<int, std::string> ocspResponses = {};
-          auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig, ocspResponses);
+          auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig);
         }
         catch (const std::runtime_error& e) {
           errlog("Ignoring DoH3 frontend: '%s'", e.what());
@@ -2423,8 +2421,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         // we are asked to try to load the certificates so we can return a potential error
         // and properly ignore the frontend before actually launching it
         try {
-          std::map<int, std::string> ocspResponses = {};
-          auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig, ocspResponses);
+          auto ctx = libssl_init_server_context(frontend->d_quicheParams.d_tlsConfig);
         }
         catch (const std::runtime_error& e) {
           errlog("Ignoring DoQ frontend: '%s'", e.what());
@@ -2776,8 +2773,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         // we are asked to try to load the certificates so we can return a potential error
         // and properly ignore the frontend before actually launching it
         try {
-          std::map<int, std::string> ocspResponses = {};
-          auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
+          auto ctx = libssl_init_server_context(frontend->d_tlsConfig);
         }
         catch (const std::runtime_error& e) {
           errlog("Ignoring TLS frontend: '%s'", e.what());
index fc03c86430431b5196b0150779b61b46c7337677..3eb1bccd548ae99073e8e581677e52644f9e3157 100644 (file)
@@ -1482,7 +1482,7 @@ static void setupTLSContext(DOHAcceptContext& acceptCtx,
     tlsConfig.d_ciphers = DOH_DEFAULT_CIPHERS.data();
   }
 
-  auto [ctx, warnings] = libssl_init_server_context(tlsConfig, acceptCtx.d_ocspResponses);
+  auto [ctx, warnings] = libssl_init_server_context_no_sni(tlsConfig, acceptCtx.d_ocspResponses);
   for (const auto& warning : warnings) {
     warnlog("%s", warning);
   }
@@ -1504,7 +1504,7 @@ static void setupTLSContext(DOHAcceptContext& acceptCtx,
   }
 #endif /* DISABLE_OCSP_STAPLING */
 
-  libssl_set_error_counters_callback(ctx, &counters);
+  libssl_set_error_counters_callback(*ctx.get(), &counters);
 
   if (!tlsConfig.d_keyLogFile.empty()) {
     acceptCtx.d_keyLogFile = libssl_set_key_log_file(ctx.get(), tlsConfig.d_keyLogFile);
index ce2993dcd66bc62369b592a3e4acb4040becff1c..f4b58f7d4832e91f55f757345e78a90b23bf41ab 100644 (file)
@@ -215,12 +215,12 @@ void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHT
   for (const auto& pair : params.d_tlsConfig.d_certKeyPairs) {
     auto res = quiche_config_load_cert_chain_from_pem_file(config.get(), pair.d_cert.c_str());
     if (res != 0) {
-      throw std::runtime_error("Error loading the server certificate: " + std::to_string(res));
+      throw std::runtime_error("Error loading the server certificate from '" + pair.d_cert + "': " + std::to_string(res));
     }
     if (pair.d_key) {
       res = quiche_config_load_priv_key_from_pem_file(config.get(), pair.d_key->c_str());
       if (res != 0) {
-        throw std::runtime_error("Error loading the server key: " + std::to_string(res));
+        throw std::runtime_error("Error loading the server key from '" + *(pair.d_key) + "': " + std::to_string(res));
       }
     }
   }
index 8b7b3e5de291b7e2b12df9c55eec8f63fb284d51..cd386b7f9d4e7a9aa61c9004290934701869633a 100644 (file)
@@ -28,6 +28,7 @@
 #endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */
 #include <openssl/rand.h>
 #include <openssl/ssl.h>
+#include <openssl/x509v3.h>
 #include <fcntl.h>
 
 #if OPENSSL_VERSION_MAJOR >= 3
@@ -353,10 +354,10 @@ static void libssl_info_callback(const SSL *ssl, int where, int /* ret */)
   }
 }
 
-void libssl_set_error_counters_callback(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, TLSErrorCounters* counters)
+void libssl_set_error_counters_callback(SSL_CTX& ctx, TLSErrorCounters* counters)
 {
-  SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters);
-  SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback);
+  SSL_CTX_set_ex_data(&ctx, s_countersIndex, counters);
+  SSL_CTX_set_info_callback(&ctx, libssl_info_callback);
 }
 
 #ifndef DISABLE_OCSP_STAPLING
@@ -518,12 +519,12 @@ bool libssl_generate_ocsp_response(const std::string& certFile, const std::strin
 #endif /* HAVE_OCSP_BASIC_SIGN */
 #endif /* DISABLE_OCSP_STAPLING */
 
-static int libssl_get_last_key_type(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx)
+static int libssl_get_last_key_type(SSL_CTX& ctx)
 {
 #ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY
-  auto pkey = SSL_CTX_get0_privatekey(ctx.get());
+  auto pkey = SSL_CTX_get0_privatekey(&ctx);
 #else
-  auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(ctx.get()), SSL_free);
+  auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(&ctx), SSL_free);
   if (!temp) {
     return -1;
   }
@@ -537,6 +538,61 @@ static int libssl_get_last_key_type(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_f
   return EVP_PKEY_base_id(pkey);
 }
 
+struct StackOfNamesDeleter
+{
+  void operator()(STACK_OF(GENERAL_NAME)* ptr) const noexcept {
+    sk_GENERAL_NAME_pop_free(ptr, GENERAL_NAME_free);
+  }
+};
+
+static std::unordered_set<std::string> get_names_from_certificate(const X509* certificate)
+{
+  std::unordered_set<std::string> result;
+  auto names = std::unique_ptr<STACK_OF(GENERAL_NAME), StackOfNamesDeleter>(static_cast<STACK_OF(GENERAL_NAME)*>(X509_get_ext_d2i(certificate, NID_subject_alt_name, nullptr, nullptr)));
+  if (names) {
+    for (int idx = 0; idx < sk_GENERAL_NAME_num(names.get()); idx++) {
+      const auto* name = sk_GENERAL_NAME_value(names.get(), idx);
+      if (name->type != GEN_DNS) {
+        /* ignore GEN_IPADD / name->d.iPAddress (raw IP address bytes), it cannot be used in SNI anyway */
+        continue;
+      }
+      unsigned char* str = nullptr;
+      if (ASN1_STRING_to_UTF8(&str, name->d.dNSName) < 0) {
+        continue;
+      }
+      result.emplace(reinterpret_cast<const char*>(str));
+      OPENSSL_free(str);
+    }
+  }
+
+  auto* name = X509_get_subject_name(certificate);
+  if (name != nullptr) {
+    ssize_t idx = -1;
+    while ((idx = X509_NAME_get_index_by_NID(name, NID_commonName, idx)) != -1) {
+      const auto* entry = X509_NAME_get_entry(name, idx);
+      const auto* value = X509_NAME_ENTRY_get_data(entry);
+      unsigned char* str = nullptr;
+      if (ASN1_STRING_to_UTF8(&str, value) < 0) {
+        continue;
+      }
+      result.emplace(reinterpret_cast<const char*>(str));
+      OPENSSL_free(str);
+    }
+  }
+
+  return result;
+}
+
+static std::unordered_set<std::string> get_names_from_last_certificate(const SSL_CTX& ctx)
+{
+  const auto* cert = SSL_CTX_get0_certificate(&ctx);
+  if (cert == nullptr) {
+    return {};
+  }
+
+  return get_names_from_certificate(cert);
+}
+
 LibsslTLSVersion libssl_tls_version_from_string(const std::string& str)
 {
   if (str == "tls1.0") {
@@ -570,7 +626,7 @@ const std::string& libssl_tls_version_to_string(LibsslTLSVersion version)
   return it->second;
 }
 
-bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, LibsslTLSVersion version)
+static bool libssl_set_min_tls_version(SSL_CTX& ctx, LibsslTLSVersion version)
 {
 #if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version)
   /* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated
@@ -597,7 +653,7 @@ bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)
     return false;
   }
 
-  if (SSL_CTX_set_min_proto_version(ctx.get(), vers) != 1) {
+  if (SSL_CTX_set_min_proto_version(&ctx, vers) != 1) {
     return false;
   }
   return true;
@@ -619,8 +675,8 @@ bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)
     return false;
   }
 
-  long options = SSL_CTX_get_options(ctx.get());
-  SSL_CTX_set_options(ctx.get(), options | vers);
+  long options = SSL_CTX_get_options(&ctx);
+  SSL_CTX_set_options(&ctx, options | vers);
   return true;
 #endif
 }
@@ -894,10 +950,8 @@ bool OpenSSLTLSTicketKey::decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx,
   return true;
 }
 
-std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config,
-                                                                                                                  [[maybe_unused]] std::map<int, std::string>& ocspResponses)
+static std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> getNewServerContext(const TLSConfig& config, [[maybe_unused]] std::vector<std::string>& warnings)
 {
-  std::vector<std::string> warnings;
   auto ctx = std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
 
   if (!ctx) {
@@ -952,7 +1006,7 @@ std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::st
 #endif
 
   SSL_CTX_set_options(ctx.get(), sslOptions);
-  if (!libssl_set_min_tls_version(ctx, config.d_minTLSVersion)) {
+  if (!libssl_set_min_tls_version(*ctx.get(), config.d_minTLSVersion)) {
     throw std::runtime_error("Failed to set the minimum version to '" + libssl_tls_version_to_string(config.d_minTLSVersion));
   }
 
@@ -992,6 +1046,29 @@ std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::st
      session is resumed, causing SSL_get_servername to return nullptr */
   SSL_CTX_set_tlsext_servername_callback(ctx.get(), &libssl_server_name_callback);
 
+  return ctx;
+}
+
+static void mergeNewCertificateAndKey(pdns::libssl::ServerContext& serverContext, pdns::libssl::ServerContext::SharedContext newContext, std::unordered_set<std::string>& names, const std::function<void(pdns::libssl::ServerContext::SharedContext&)>& existingContextCallback)
+{
+  for (const auto& name : names) {
+    auto [existingEntry, inserted] = serverContext.d_sniMap.emplace(name, newContext);
+    if (!inserted) {
+      auto& existingContext = existingEntry->second;
+      existingContextCallback(existingContext);
+    }
+    else if (serverContext.d_sniMap.size() == 1) {
+      serverContext.d_defaultContext = newContext;
+    }
+  }
+}
+
+std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context_no_sni(const TLSConfig& config,
+                                                                                                                         [[maybe_unused]] std::map<int, std::string>& ocspResponses)
+{
+  std::vector<std::string> warnings;
+  auto ctx = getNewServerContext(config, warnings);
+
   std::vector<int> keyTypes;
   /* load certificate and private key */
   for (const auto& pair : config.d_certKeyPairs) {
@@ -1055,12 +1132,13 @@ std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::st
         throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.d_key.value());
       }
     }
+
     if (SSL_CTX_check_private_key(ctx.get()) != 1) {
       ERR_print_errors_fp(stderr);
       throw std::runtime_error("The key from '" + pair.d_key.value() + "' does not match the certificate from '" + pair.d_cert + "'");
     }
     /* store the type of the new key, we might need it later to select the right OCSP stapling response */
-    auto keyType = libssl_get_last_key_type(ctx);
+    auto keyType = libssl_get_last_key_type(*ctx.get());
     if (keyType < 0) {
       throw std::runtime_error("The key from '" + pair.d_key.value() + "' has an unknown type");
     }
@@ -1091,6 +1169,128 @@ std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::st
   return {std::move(ctx), std::move(warnings)};
 }
 
+std::pair<pdns::libssl::ServerContext, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config)
+{
+  std::vector<std::string> warnings;
+  pdns::libssl::ServerContext serverContext;
+
+  std::vector<int> keyTypes;
+  /* load certificate and private key */
+  for (const auto& pair : config.d_certKeyPairs) {
+    auto uniqueCtx = getNewServerContext(config, warnings);
+    auto ctx = std::shared_ptr<SSL_CTX>(uniqueCtx.release(), SSL_CTX_free);
+    if (!pair.d_key) {
+#if defined(HAVE_SSL_CTX_USE_CERT_AND_KEY)
+      // If no separate key is given, treat it as a pkcs12 file
+      auto filePtr = pdns::UniqueFilePtr(fopen(pair.d_cert.c_str(), "r"));
+      if (!filePtr) {
+        throw std::runtime_error("Unable to open file " + pair.d_cert);
+      }
+      auto p12 = std::unique_ptr<PKCS12, void(*)(PKCS12*)>(d2i_PKCS12_fp(filePtr.get(), nullptr), PKCS12_free);
+      if (!p12) {
+        throw std::runtime_error("Unable to open PKCS12 file " + pair.d_cert);
+      }
+      EVP_PKEY *keyptr = nullptr;
+      X509 *certptr = nullptr;
+      STACK_OF(X509) *captr = nullptr;
+      if (!PKCS12_parse(p12.get(), (pair.d_password ? pair.d_password->c_str() : nullptr), &keyptr, &certptr, &captr)) {
+#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3
+        bool failed = true;
+        /* we might be opening a PKCS12 file that uses RC2 CBC or 3DES CBC which, since OpenSSL 3.0.0, requires loading the legacy provider */
+        auto libCtx = OSSL_LIB_CTX_get0_global_default();
+        /* check whether the legacy provider is already loaded */
+        if (!OSSL_PROVIDER_available(libCtx, "legacy")) {
+          /* it's not */
+          auto provider = OSSL_PROVIDER_load(libCtx, "legacy");
+          if (provider != nullptr) {
+            if (PKCS12_parse(p12.get(), (pair.d_password ? pair.d_password->c_str() : nullptr), &keyptr, &certptr, &captr)) {
+              failed = false;
+            }
+            /* we do not want to keep that provider around after that */
+            OSSL_PROVIDER_unload(provider);
+          }
+        }
+        if (failed) {
+#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */
+          ERR_print_errors_fp(stderr);
+          throw std::runtime_error("An error occured while parsing PKCS12 file " + pair.d_cert);
+#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3
+        }
+#endif /* defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 */
+      }
+      auto key = std::unique_ptr<EVP_PKEY, void(*)(EVP_PKEY*)>(keyptr, EVP_PKEY_free);
+      auto cert = std::unique_ptr<X509, void(*)(X509*)>(certptr, X509_free);
+      auto ca = std::unique_ptr<STACK_OF(X509), void(*)(STACK_OF(X509)*)>(captr, [](STACK_OF(X509)* st){ sk_X509_free(st); });
+
+      auto addCertificateAndKey = [&pair, &key, &cert, &ca](std::shared_ptr<SSL_CTX>& tlsContext) {
+        if (SSL_CTX_use_cert_and_key(tlsContext.get(), cert.get(), key.get(), ca.get(), 1) != 1) {
+          ERR_print_errors_fp(stderr);
+          throw std::runtime_error("An error occurred while trying to load the TLS certificate and key from PKCS12 file " + pair.d_cert);
+        }
+      };
+
+      addCertificateAndKey(ctx);
+      auto names = get_names_from_last_certificate(*ctx);
+      mergeNewCertificateAndKey(serverContext, ctx, names, addCertificateAndKey);
+#else
+      throw std::runtime_error("PKCS12 files are not supported by your openssl version");
+#endif /* HAVE_SSL_CTX_USE_CERT_AND_KEY */
+    } else {
+      auto addCertificateAndKey = [&pair](std::shared_ptr<SSL_CTX>& tlsContext) {
+        if (SSL_CTX_use_certificate_chain_file(tlsContext.get(), pair.d_cert.c_str()) != 1) {
+          ERR_print_errors_fp(stderr);
+          throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.d_cert);
+        }
+        if (SSL_CTX_use_PrivateKey_file(tlsContext.get(), pair.d_key->c_str(), SSL_FILETYPE_PEM) != 1) {
+          ERR_print_errors_fp(stderr);
+          throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.d_key.value());
+        }
+      };
+
+      addCertificateAndKey(ctx);
+      auto names = get_names_from_last_certificate(*ctx);
+      mergeNewCertificateAndKey(serverContext, ctx, names, addCertificateAndKey);
+    }
+
+    if (SSL_CTX_check_private_key(ctx.get()) != 1) {
+      ERR_print_errors_fp(stderr);
+      throw std::runtime_error("The key from '" + pair.d_key.value() + "' does not match the certificate from '" + pair.d_cert + "'");
+    }
+    /* store the type of the new key, we might need it later to select the right OCSP stapling response */
+    auto keyType = libssl_get_last_key_type(*ctx.get());
+    if (keyType < 0) {
+      throw std::runtime_error("The key from '" + pair.d_key.value() + "' has an unknown type");
+    }
+    keyTypes.push_back(keyType);
+ }
+
+#ifndef DISABLE_OCSP_STAPLING
+  if (!config.d_ocspFiles.empty()) {
+    try {
+      serverContext.d_ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, std::move(keyTypes), warnings);
+    }
+    catch(const std::exception& e) {
+      throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what()));
+    }
+  }
+#endif /* DISABLE_OCSP_STAPLING */
+
+  for (auto& entry : serverContext.d_sniMap) {
+    auto& ctx = entry.second;
+    if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) {
+      throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers);
+    }
+
+#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
+    if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) {
+      throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13);
+    }
+#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
+  }
+
+  return {std::move(serverContext), std::move(warnings)};
+}
+
 #ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
 static void libssl_key_log_file_callback(const SSL* ssl, const char* line)
 {
index 3d3e5a8eba48e49fda9e5917ba5c607d99ddf2db..fd545019e401be21fe2e3707f661e3c8f4d6626f 100644 (file)
@@ -147,16 +147,30 @@ bool libssl_generate_ocsp_response(const std::string& certFile, const std::strin
 #endif
 #endif /* DISABLE_OCSP_STAPLING */
 
-void libssl_set_error_counters_callback(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, TLSErrorCounters* counters);
+void libssl_set_error_counters_callback(SSL_CTX& ctx, TLSErrorCounters* counters);
 
 LibsslTLSVersion libssl_tls_version_from_string(const std::string& str);
 const std::string& libssl_tls_version_to_string(LibsslTLSVersion version);
-bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>& ctx, LibsslTLSVersion version);
+
+
+namespace pdns::libssl {
+class ServerContext
+{
+public:
+  using SharedContext = std::shared_ptr<SSL_CTX>;
+  using SNIToContextMap  = std::map<std::string, SharedContext, std::less<>>;
+
+  SharedContext d_defaultContext;
+  SNIToContextMap d_sniMap;
+  std::map<int, std::string> d_ocspResponses;
+};
+}
 
 /* return the created context, and a list of warning messages for issues not severe enough
    to trigger raising an exception, like failing to load an OCSP response file */
-std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config,
-                                                                                                            std::map<int, std::string>& ocspResponses);
+std::pair<std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)>, std::vector<std::string>> libssl_init_server_context_no_sni(const TLSConfig& config,
+                                                                                                                         std::map<int, std::string>& ocspResponses);
+std::pair<pdns::libssl::ServerContext, std::vector<std::string>> libssl_init_server_context(const TLSConfig& config);
 
 pdns::UniqueFilePtr libssl_set_key_log_file(SSL_CTX* ctx, const std::string& logFile);
 
index 9379576b6088706a7283f2f1bbc413add184b2c2..60cb189258079a2ad05030794f429f617db8c93f 100644 (file)
@@ -57,6 +57,7 @@ bool shouldDoVerboseLogging()
 
 #include "libssl.hh"
 
+static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* arg);
 
 class OpenSSLFrontendContext
 {
@@ -65,11 +66,16 @@ public:
   {
     registerOpenSSLUser();
 
-    auto [ctx, warnings] = libssl_init_server_context(tlsConfig, d_ocspResponses);
+    auto [ctx, warnings] = libssl_init_server_context(tlsConfig);
     for (const auto& warning : warnings) {
       warnlog("%s", warning);
     }
-    d_tlsCtx = std::move(ctx);
+    d_ocspResponses = std::move(ctx.d_ocspResponses);
+    d_tlsCtx = std::move(ctx.d_defaultContext);
+    d_sniMap = std::move(ctx.d_sniMap);
+    for (auto& entry : d_sniMap) {
+      SSL_CTX_set_tlsext_servername_callback(entry.second.get(), &sni_server_name_callback);
+    }
 
     if (!d_tlsCtx) {
       ERR_print_errors_fp(stderr);
@@ -86,10 +92,38 @@ public:
 
   OpenSSLTLSTicketKeysRing d_ticketKeys;
   std::map<int, std::string> d_ocspResponses;
-  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
+  pdns::libssl::ServerContext::SNIToContextMap d_sniMap;
+  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr};
   pdns::UniqueFilePtr d_keyLogFile{nullptr};
 };
 
+
+static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* /* arg */)
+{
+  const auto* serverName = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+  if (serverName == nullptr) {
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+  auto* frontendCtx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(ssl));
+  if (frontendCtx == nullptr) {
+    return SSL_TLSEXT_ERR_OK;
+  }
+
+  auto serverNameView = std::string_view(serverName);
+
+  auto it = frontendCtx->d_sniMap.find(serverNameView);
+  if (it == frontendCtx->d_sniMap.end()) {
+    /* keep the default certificate */
+    return SSL_TLSEXT_ERR_OK;
+  }
+
+  /* if it fails there is nothing we can do,
+     let's hope OpenSSL will fallback to the existing,
+     default certificate*/
+  SSL_set_SSL_CTX(ssl, it->second.get());
+  return SSL_TLSEXT_ERR_OK;
+}
+
 class OpenSSLSession : public TLSSession
 {
 public:
@@ -649,33 +683,36 @@ public:
 
     d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;
 
-    if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
-      /* use our own ticket keys handler so we can rotate them */
+    for (auto& entry : d_feContext->d_sniMap) {
+      auto* ctx = entry.second.get();
+      if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
+        /* use our own ticket keys handler so we can rotate them */
 #if OPENSSL_VERSION_MAJOR >= 3
-      SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+        SSL_CTX_set_tlsext_ticket_key_evp_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
 #else
-      SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+        SSL_CTX_set_tlsext_ticket_key_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
 #endif
-      libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
-    }
+        libssl_set_ticket_key_callback_data(ctx, d_feContext.get());
+      }
 
 #ifndef DISABLE_OCSP_STAPLING
-    if (!d_feContext->d_ocspResponses.empty()) {
-      SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
-      SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
-    }
+      if (!d_feContext->d_ocspResponses.empty()) {
+        SSL_CTX_set_tlsext_status_cb(ctx, &OpenSSLTLSIOCtx::ocspStaplingCb);
+        SSL_CTX_set_tlsext_status_arg(ctx, &d_feContext->d_ocspResponses);
+      }
 #endif /* DISABLE_OCSP_STAPLING */
 
-    if (frontend.d_tlsConfig.d_readAhead) {
-      SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
-    }
+      if (frontend.d_tlsConfig.d_readAhead) {
+        SSL_CTX_set_read_ahead(ctx, 1);
+      }
 
-    libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &frontend.d_tlsCounters);
+      libssl_set_error_counters_callback(*ctx, &frontend.d_tlsCounters);
 
-    libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
+      libssl_set_alpn_select_callback(ctx, alpnServerSelectCallback, this);
 
-    if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
-      d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx.get(), frontend.d_tlsConfig.d_keyLogFile);
+      if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
+        d_feContext->d_keyLogFile = libssl_set_key_log_file(ctx, frontend.d_tlsConfig.d_keyLogFile);
+      }
     }
 
     try {
index 0b347c4993b5ba235eb1fe19dc6c77e37fccc182..f5c450fbbd4ba282757c0f25cdb356af46f0b186 100644 (file)
 /server.csr
 /server.key
 /server.pem
+/server2.chain
+/server2.csr
+/server2.key
+/server2.pem
 /server.p12
+/server-ec.*
 /server-doq.*
 /server-doh3.*
 /server-ocsp.chain
index 84286d7a4a95fc6996860447f6e9066ce27babb8..e851c8c1492b29530b6cfee82195a45fb55d7cd2 100644 (file)
@@ -13,3 +13,17 @@ certs:
        cat server.pem ca.pem > server.chain
        # Generate a password-protected PKCS12 file
        openssl pkcs12 -export -passout pass:passw0rd -clcerts -in server.pem -CAfile ca.pem -inkey server.key -out server.p12
+       # Generate a second server certificate request
+       openssl req -new -newkey rsa:2048 -nodes -keyout server2.key -out server2.csr -config configServer2.conf
+       # Sign the server cert
+       openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server2.csr -out server2.pem -extfile configServer2.conf -extensions v3_req
+       # Generate a chain
+       cat server2.pem ca.pem > server2.chain
+       # Generate a ECDSA key with P-256
+       openssl ecparam -name secp256r1 -genkey -noout -out server-ec.key
+       # Generate a new server certificate request with the ECDSA key
+       openssl req -new -key server-ec.key -nodes -out server-ec.csr -config configServer.conf
+       # Sign the server cert
+       openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server-ec.csr -out server-ec.pem -extfile configServer.conf -extensions v3_req
+       # Generate a chain
+       cat server-ec.pem ca.pem > server-ec.chain
diff --git a/regression-tests.dnsdist/configServer2.conf b/regression-tests.dnsdist/configServer2.conf
new file mode 100644 (file)
index 0000000..1208d58
--- /dev/null
@@ -0,0 +1,20 @@
+[req]
+default_bits = 2048
+encrypt_key = no
+prompt = no
+distinguished_name = server_distinguished_name
+req_extensions = v3_req
+
+[server_distinguished_name]
+CN = tls2.tests.dnsdist.org
+OU = PowerDNS.com BV
+countryName = NL
+
+[v3_req]
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+subjectAltName = @alt_names
+
+[alt_names]
+DNS.1 = tls2.tests.dnsdist.org
+IP.2 = 192.0.2.1
index 53c97b04b476a20634588f6dfe7cdb952584dac2..c14ab88310ad52856801d02c769ffc15d7f48d67 100644 (file)
@@ -1129,23 +1129,23 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         cls._response_headers = response_headers.getvalue()
         return (receivedQuery, message)
 
-    def sendDOHQueryWrapper(self, query, response, useQueue=True, timeout=2):
-        return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
+    def sendDOHQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+        return self.sendDOHQuery(self._dohServerPort, self._serverName if not serverName else serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
 
-    def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True, timeout=2):
-        return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
+    def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+        return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName if not serverName else serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
 
-    def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True, timeout=2):
-        return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
+    def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+        return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName if not serverName else serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, timeout=timeout)
 
-    def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2):
-        return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout)
+    def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+        return self.sendDOTQuery(self._tlsServerPort, self._serverName if not serverName else serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout)
 
-    def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2):
-        return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, timeout=timeout)
+    def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+        return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout)
 
-    def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2):
-        return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, timeout=timeout)
+    def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
+        return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout)
     @classmethod
     def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
 
index 1e93cc6ce463823467ad705a1e15296a71c3dd28..31ac684d47e8c28c9b3e5f88d08dcc0fd57b51d1 100644 (file)
@@ -4,13 +4,20 @@ import dns
 import os
 import unittest
 import pycurl
+import ssl
 
 from dnsdisttests import DNSDistTest, pickAvailablePort
 
 class TestSNI(DNSDistTest):
     _serverKey = 'server.key'
     _serverCert = 'server.chain'
+    _serverKeyEC = 'server-ec.key'
+    _serverCertEC = 'server-ec.chain'
+    _serverKey2 = 'server2.key'
+    _serverCert2 = 'server2.chain'
     _serverName = 'tls.tests.dnsdist.org'
+    _serverName2 = 'tls2.tests.dnsdist.org'
+    _serverName3 = 'unknown.tests.dnsdist.org'
     _caCert = 'ca.pem'
     _tlsServerPort = pickAvailablePort()
     _dohWithNGHTTP2ServerPort = pickAvailablePort()
@@ -22,21 +29,37 @@ class TestSNI(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%d"}
 
-    addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
-    addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
-    addDOQLocal("127.0.0.1:%d", "%s", "%s")
-    addDOH3Local("127.0.0.1:%d", "%s", "%s")
+    local certs = {"%s", "%s", "%s"}
+    local keys = {"%s", "%s", "%s"}
+    local single_cert = "%s"
+    local single_key = "%s"
+    addTLSLocal("127.0.0.1:%d", certs, keys, { provider="openssl" })
+    addDOHLocal("127.0.0.1:%d", certs, keys, {"/"}, {library="nghttp2"})
+    addDOQLocal("127.0.0.1:%d", single_cert, single_key)
+    addDOH3Local("127.0.0.1:%d", single_cert, single_key)
 
-    function displaySNI(dq)
+    function checkSNI(dq)
       local sni = dq:getServerNameIndication()
-      if sni ~= '%s' then
+      if tostring(dq.qname) == 'simple.sni.tests.powerdns.com.' and sni ~= '%s' then
         return DNSAction.Spoof, '1.2.3.4'
       end
+      if tostring(dq.qname) == 'name2.sni.tests.powerdns.com.' and sni ~= '%s' then
+        return DNSAction.Spoof, '2.3.4.5'
+      end
+      if tostring(dq.qname) == 'unknown.sni.tests.powerdns.com.' and sni ~= '%s' then
+        return DNSAction.Spoof, '3.4.5.6'
+      end
+      if tostring(dq.qname) == 'ecdsa.sni.tests.powerdns.com.' and sni ~= '%s' then
+        return DNSAction.Spoof, '4.5.6.7'
+      end
+      if tostring(dq.qname) == 'rsa.sni.tests.powerdns.com.' and sni ~= '%s' then
+        return DNSAction.Spoof, '4.5.6.7'
+      end
       return DNSAction.Allow
     end
-    addAction(AllRule(), LuaAction(displaySNI))
+    addAction(AllRule(), LuaAction(checkSNI))
     """
-    _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName']
+    _config_params = ['_testServerPort', '_serverCert', '_serverCertEC', '_serverCert2', '_serverKey', '_serverKeyEC', '_serverKey2', '_serverCert', '_serverKey', '_tlsServerPort', '_dohWithNGHTTP2ServerPort', '_doqServerPort', '_doh3ServerPort', '_serverName', '_serverName2', '_serverName3', '_serverName', '_serverName']
 
     @unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quiche are disabled")
     def testServerNameIndicationWithQuiche(self):
@@ -79,3 +102,93 @@ class TestSNI(DNSDistTest):
             self.assertEqual(query, receivedQuery)
             self.assertTrue(receivedResponse)
             self.assertEqual(response, receivedResponse)
+
+        # check second certificate
+        name = 'name2.sni.tests.powerdns.com.'
+        self._dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (self._serverName2, self._dohWithNGHTTP2ServerPort))
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]:
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response, timeout=1, serverName=self._serverName2)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(response, receivedResponse)
+
+        # check SNI for an unkown name, we should get the first certificate
+        name = 'unknown.sni.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        sslctx = ssl.create_default_context(cafile=self._caCert)
+        sslctx.check_hostname = False
+        if hasattr(sslctx, 'set_alpn_protocols'):
+            sslctx.set_alpn_protocols(self._serverName3)
+
+        conn = self.openTLSConnection(self._tlsServerPort, self._serverName3, self._caCert, timeout=1, sslctx=sslctx)
+        self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1)
+        (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1)
+        receivedQuery.id = query.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+
+        cert = conn.getpeercert()
+        subject = cert['subject']
+        altNames = cert['subjectAltName']
+        self.assertEqual(dict(subject[0])['commonName'], 'tls.tests.dnsdist.org')
+        self.assertEqual(dict(subject[1])['organizationalUnitName'], 'PowerDNS.com BV')
+        names = []
+        for entry in altNames:
+            names.append(entry[1])
+        self.assertEqual(names, ['tls.tests.dnsdist.org', 'powerdns.com', '127.0.0.1'])
+
+        # check that we provide the correct RSA/ECDSA certificate when requested
+        for algo in ['rsa', 'ecdsa']:
+            name = algo + '.sni.tests.powerdns.com.'
+            query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        3600,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '127.0.0.1')
+            response.answer.append(rrset)
+
+            sslctx = ssl.create_default_context(cafile=self._caCert)
+            if hasattr(sslctx, 'set_alpn_protocols'):
+                sslctx.set_alpn_protocols(self._serverName)
+            # disable TLS 1.3 because configuring the signature algorithm is not supported by Python yet
+            sslctx.maximum_version = ssl.TLSVersion.TLSv1_2
+            # explicitly request authentication via RSA or ECDSA
+            sslctx.set_ciphers('a' + algo.upper())
+
+            conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert, timeout=1, sslctx=sslctx)
+            self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1)
+            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1)
+            receivedQuery.id = query.id
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
+
+            cert = conn.getpeercert()
+            subject = cert['subject']
+            altNames = cert['subjectAltName']
+            self.assertEqual(dict(subject[0])['commonName'], 'tls.tests.dnsdist.org')
+            self.assertEqual(dict(subject[1])['organizationalUnitName'], 'PowerDNS.com BV')
+            names = []
+            for entry in altNames:
+                names.append(entry[1])
+            self.assertEqual(names, ['tls.tests.dnsdist.org', 'powerdns.com', '127.0.0.1'])