]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Use ref counting for the DoT TLS context 8761/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 29 Jan 2020 10:33:01 +0000 (11:33 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 29 Jan 2020 10:33:01 +0000 (11:33 +0100)
Otherwise we can end up with a DNS over TLS connection using a
TLS Session Ticket Encryption Key, OCSP response or even `SSL_CTX`
object after it was released following a reload of the TLS context
(via `reloadAllCertificates()`, for example), triggering a
use-after-free, possibly leading to a crash.

pdns/dnsdistdist/tcpiohandler.cc

index 535ec180ad64f9d118fba5748727754787541fd3..59017d8f81cb268bf18bd5b483519d4688fa9ea3 100644 (file)
 
 #include "libssl.hh"
 
+class OpenSSLFrontendContext
+{
+public:
+  OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
+  {
+    registerOpenSSLUser();
+
+    d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
+    if (!d_tlsCtx) {
+      ERR_print_errors_fp(stderr);
+      throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
+    }
+  }
+
+  void cleanup()
+  {
+    d_tlsCtx.reset();
+
+    unregisterOpenSSLUser();
+  }
+
+  OpenSSLTLSTicketKeysRing d_ticketKeys;
+  std::map<int, std::string> d_ocspResponses;
+  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
+};
+
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
-  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
+  OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
   {
     d_socket = socket;
 
@@ -247,6 +273,7 @@ public:
 private:
   static std::atomic_flag s_initTLSConnIndex;
 
+  std::shared_ptr<OpenSSLFrontendContext> d_feContext;
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
   unsigned int d_timeout;
 };
@@ -257,29 +284,24 @@ int OpenSSLTLSConnection::s_tlsConnIndex = -1;
 class OpenSSLTLSIOCtx: public TLSCtx
 {
 public:
-  OpenSSLTLSIOCtx(TLSFrontend& fe): d_ticketKeys(fe.d_tlsConfig.d_numberOfTicketsKeys)
+  OpenSSLTLSIOCtx(TLSFrontend& fe)
   {
-    registerOpenSSLUser();
-    d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
+    d_feContext = std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig);
 
-    d_tlsCtx = libssl_init_server_context(fe.d_tlsConfig, d_ocspResponses);
-    if (!d_tlsCtx) {
-      ERR_print_errors_fp(stderr);
-      throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort());
-    }
+    d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
 
     if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
       /* use our own ticket keys handler so we can rotate them */
-      SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
-      libssl_set_ticket_key_callback_data(d_tlsCtx.get(), this);
+      SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+      libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
     }
 
-    if (!d_ocspResponses.empty()) {
-      SSL_CTX_set_tlsext_status_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
-      SSL_CTX_set_tlsext_status_arg(d_tlsCtx.get(), &d_ocspResponses);
+    if (!d_feContext->d_ocspResponses.empty()) {
+      SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
+      SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
     }
 
-    libssl_set_error_counters_callback(d_tlsCtx, &fe.d_tlsCounters);
+    libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
 
     try {
       if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
@@ -294,16 +316,13 @@ public:
     }
   }
 
-  virtual ~OpenSSLTLSIOCtx() override
+  ~OpenSSLTLSIOCtx() override
   {
-    d_tlsCtx.reset();
-
-    unregisterOpenSSLUser();
   }
 
   static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
   {
-    OpenSSLTLSIOCtx* ctx = reinterpret_cast<OpenSSLTLSIOCtx*>(libssl_get_ticket_key_callback_data(s));
+    OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
     if (ctx == nullptr) {
       return -1;
     }
@@ -339,12 +358,12 @@ public:
   {
     handleTicketsKeyRotation(now);
 
-    return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get()));
+    return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_feContext));
   }
 
   void rotateTicketsKey(time_t now) override
   {
-    d_ticketKeys.rotateTicketsKey(now);
+    d_feContext->d_ticketKeys.rotateTicketsKey(now);
 
     if (d_ticketsKeyRotationDelay > 0) {
       d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
@@ -353,7 +372,7 @@ public:
 
   void loadTicketsKeys(const std::string& keyFile) override final
   {
-    d_ticketKeys.loadTicketsKeys(keyFile);
+    d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
 
     if (d_ticketsKeyRotationDelay > 0) {
       d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
@@ -362,13 +381,11 @@ public:
 
   size_t getTicketsKeysCount() override
   {
-    return d_ticketKeys.getKeysCount();
+    return d_feContext->d_ticketKeys.getKeysCount();
   }
 
 private:
-  OpenSSLTLSTicketKeysRing d_ticketKeys;
-  std::map<int, std::string> d_ocspResponses;
-  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
+  std::shared_ptr<OpenSSLFrontendContext> d_feContext;
 };
 
 #endif /* HAVE_LIBSSL */