]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Use per-thread credentials for GnuTLS client connections 10841/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 13 Oct 2021 12:03:45 +0000 (14:03 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 13 Oct 2021 12:03:45 +0000 (14:03 +0200)
It looks like there is a race in some versions when the credentials
are shared between several threads opening TLS client connections.

pdns/dnsdistdist/dnsdist-tsan.supp
pdns/iputils.hh
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc

index 39ffe6cd485e7b3616a7ea7b3dc2ecf2e7148d73..c8c2d50e340d9e759d251a56dcd14e690eaf883f 100644 (file)
@@ -11,3 +11,4 @@ race:DownstreamState::setDown
 race:DownstreamState::setUp
 race:DownstreamState::setAuto
 race:updateHealthCheckResult
+race:carbonDumpThread
index 42e7a2bffebad012b30dd7fdfc8e8d785eb4fecc..45690731a7b5da600d83a2f0f540c248b41ba2b0 100644 (file)
@@ -30,7 +30,6 @@
 #include <bitset>
 #include "pdnsexception.hh"
 #include "misc.hh"
-#include <sys/socket.h>
 #include <netdb.h>
 #include <sstream>
 #include <boost/tuple/tuple.hpp>
index 19c6d78a7390b55b885c9eead06631e2ba8df6b1..2c0e8be68251e9a9a142f6dd6589f33864cfca16 100644 (file)
@@ -804,21 +804,21 @@ std::unique_ptr<FILE, int(*)(FILE*)> libssl_set_key_log_file(std::unique_ptr<SSL
 }
 
 /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
-void libssl_set_npn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
+void libssl_set_npn_select_callback(SSL_CTX* ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
 {
 #ifdef HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB
-  SSL_CTX_set_next_proto_select_cb(ctx.get(), cb, arg);
+  SSL_CTX_set_next_proto_select_cb(ctx, cb, arg);
 #endif
 }
 
-void libssl_set_alpn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
+void libssl_set_alpn_select_callback(SSL_CTX* ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
 {
 #ifdef HAVE_SSL_CTX_SET_ALPN_SELECT_CB
-  SSL_CTX_set_alpn_select_cb(ctx.get(), cb, arg);
+  SSL_CTX_set_alpn_select_cb(ctx, cb, arg);
 #endif
 }
 
-bool libssl_set_alpn_protos(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::vector<std::vector<uint8_t>>& protos)
+bool libssl_set_alpn_protos(SSL_CTX* ctx, const std::vector<std::vector<uint8_t>>& protos)
 {
 #ifdef HAVE_SSL_CTX_SET_ALPN_PROTOS
   std::vector<uint8_t> wire;
@@ -830,7 +830,7 @@ bool libssl_set_alpn_protos(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, co
     wire.push_back(length);
     wire.insert(wire.end(), proto.begin(), proto.end());
   }
-  return SSL_CTX_set_alpn_protos(ctx.get(), wire.data(), wire.size()) == 0;
+  return SSL_CTX_set_alpn_protos(ctx, wire.data(), wire.size()) == 0;
 #else
   return false;
 #endif
index 2af0f4ef8947ec96523b3818e540a35876ee32db..4561b4a496f10b7671f42a00d564f421e03104aa 100644 (file)
@@ -127,11 +127,11 @@ std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> libssl_init_server_context(const TLS
 std::unique_ptr<FILE, int(*)(FILE*)> libssl_set_key_log_file(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::string& logFile);
 
 /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
-void libssl_set_npn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg);
+void libssl_set_npn_select_callback(SSL_CTX* ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg);
 /* called in a server context, to select an ALPN value advertised by the client if any */
-void libssl_set_alpn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg);
+void libssl_set_alpn_select_callback(SSL_CTX* ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg);
 /* set the supported ALPN protos in client context */
-bool libssl_set_alpn_protos(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::vector<std::vector<uint8_t>>& protos);
+bool libssl_set_alpn_protos(SSL_CTX* ctx, const std::vector<std::vector<uint8_t>>& protos);
 
 std::string libssl_get_error_string();
 
index 0cac651ce2e049f19867eb1b7315ee13fe937d9b..01332bbb71ae61abb7afae544e63a3123bbf79dc 100644 (file)
@@ -99,7 +99,7 @@ public:
   }
 
   /* client-side connection */
-  OpenSSLTLSConnection(const std::string& hostname, int socket, const struct timeval& timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_hostname(hostname), d_timeout(timeout)
+  OpenSSLTLSConnection(const std::string& hostname, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
   {
     d_socket = socket;
 
@@ -487,7 +487,10 @@ private:
   static std::atomic_flag s_initTLSConnIndex;
 
   std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
+  /* server context */
   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
+  /* client context */
+  std::shared_ptr<SSL_CTX> d_tlsCtx;
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
   std::string d_hostname;
   struct timeval d_timeout;
@@ -500,7 +503,7 @@ class OpenSSLTLSIOCtx: public TLSCtx
 {
 public:
   /* server side context */
-  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig)), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
+  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
   {
     d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
 
@@ -535,7 +538,7 @@ public:
   }
 
   /* client side context */
-  OpenSSLTLSIOCtx(const TLSContextParameters& params): d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
+  OpenSSLTLSIOCtx(const TLSContextParameters& params)
   {
     int sslOptions =
       SSL_OP_NO_SSLv2 |
@@ -549,9 +552,9 @@ public:
     registerOpenSSLUser();
 
 #ifdef HAVE_TLS_CLIENT_METHOD
-    d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
+    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
 #else
-    d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
+    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
 #endif
     if (!d_tlsCtx) {
       ERR_print_errors_fp(stderr);
@@ -661,7 +664,7 @@ public:
 
   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
   {
-    return std::make_unique<OpenSSLTLSConnection>(host, socket, timeout, d_tlsCtx.get());
+    return std::make_unique<OpenSSLTLSConnection>(host, socket, timeout, d_tlsCtx);
   }
 
   void rotateTicketsKey(time_t now) override
@@ -696,11 +699,11 @@ public:
   {
     if (d_feContext && d_feContext->d_tlsCtx) {
       d_alpnProtos = protos;
-      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx, alpnServerSelectCallback, this);
+      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
       return true;
     }
     if (d_tlsCtx) {
-      return libssl_set_alpn_protos(d_tlsCtx, protos);
+      return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
     }
     return false;
   }
@@ -708,7 +711,7 @@ public:
   bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
   {
     d_nextProtocolSelectCallback = cb;
-    libssl_set_npn_select_callback(d_tlsCtx, npnSelectCallback, this);
+    libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
     return true;
   }
 
@@ -757,8 +760,8 @@ private:
   }
 
   std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
-  std::shared_ptr<OpenSSLFrontendContext> d_feContext;
-  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
+  std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
+  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
   bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
 };
 
@@ -891,7 +894,7 @@ class GnuTLSConnection: public TLSConnection
 {
 public:
   /* server side connection */
-  GnuTLSConnection(int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
+  GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
   {
     unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
@@ -908,7 +911,7 @@ public:
     d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
     conn = nullptr;
 
-    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
+    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error setting certificate and key to TLS connection");
     }
 
@@ -931,7 +934,7 @@ public:
   }
 
   /* client-side connection */
-  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
+  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
   {
     unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
@@ -948,7 +951,7 @@ public:
     d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
     conn = nullptr;
 
-    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds);
+    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
     if (rc != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
     }
@@ -1404,9 +1407,10 @@ public:
   }
 
 private:
-  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
+  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
   std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
+  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
   std::string d_host;
   bool d_client{false};
   bool d_handshakeDone{false};
@@ -1416,7 +1420,7 @@ class GnuTLSIOCtx: public TLSCtx
 {
 public:
   /* server side context */
-  GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
+  GnuTLSIOCtx(TLSFrontend& fe): d_enableTickets(fe.d_tlsConfig.d_enableTickets)
   {
     int rc = 0;
     d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
@@ -1427,7 +1431,7 @@ public:
       throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
     }
 
-    d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
     creds = nullptr;
 
     for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
@@ -1472,7 +1476,7 @@ public:
   }
 
   /* client side context */
-  GnuTLSIOCtx(const TLSContextParameters& params): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
+  GnuTLSIOCtx(const TLSContextParameters& params): d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
   {
     int rc = 0;
 
@@ -1482,7 +1486,7 @@ public:
       throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
     }
 
-    d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
     creds = nullptr;
 
     if (params.d_validateCertificates) {
@@ -1524,16 +1528,49 @@ public:
       ticketsKey = *(d_ticketsKey.read_lock());
     }
 
-    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets);
+    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
     if (!d_protos.empty()) {
       connection->setALPNProtos(d_protos);
     }
     return connection;
   }
 
+  static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
+  {
+    static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
+    auto& entry = t_credentials[{validate, caStore}];
+    if (!entry) {
+      gnutls_certificate_credentials_t creds;
+      int rc = gnutls_certificate_allocate_credentials(&creds);
+      if (rc != GNUTLS_E_SUCCESS) {
+        throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
+      }
+
+      entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
+      creds = nullptr;
+
+      if (validate) {
+        if (caStore.empty()) {
+          rc = gnutls_certificate_set_x509_system_trust(entry.get());
+          if (rc < 0) {
+            throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
+          }
+        }
+        else {
+          rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
+          if (rc < 0) {
+            throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
+          }
+        }
+      }
+    }
+    return entry;
+  }
+
   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
   {
-    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts);
+    auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
+    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
     if (!d_protos.empty()) {
       connection->setALPNProtos(d_protos);
     }
@@ -1594,7 +1631,9 @@ public:
   }
 
 private:
-  std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
+  /* client context parameters */
+  std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
+  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
   std::vector<std::vector<uint8_t>> d_protos;
   gnutls_priority_t d_priorityCache{nullptr};
   SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};