From: Remi Gacogne Date: Wed, 29 Jan 2020 10:33:01 +0000 (+0100) Subject: dnsdist: Use ref counting for the DoT TLS context X-Git-Tag: auth-4.3.0-beta2~56^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F8761%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Use ref counting for the DoT TLS context Otherwise we can end up with a DNS over TLS connection using a TLS Session Ticket Encryption Key, OCSP response or even `SSL_CTX` object after it was released following a reload of the TLS context (via `reloadAllCertificates()`, for example), triggering a use-after-free, possibly leading to a crash. --- diff --git a/pdns/dnsdistdist/tcpiohandler.cc b/pdns/dnsdistdist/tcpiohandler.cc index 535ec180ad..59017d8f81 100644 --- a/pdns/dnsdistdist/tcpiohandler.cc +++ b/pdns/dnsdistdist/tcpiohandler.cc @@ -18,10 +18,36 @@ #include "libssl.hh" +class OpenSSLFrontendContext +{ +public: + OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys) + { + registerOpenSSLUser(); + + d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses); + if (!d_tlsCtx) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort()); + } + } + + void cleanup() + { + d_tlsCtx.reset(); + + unregisterOpenSSLUser(); + } + + OpenSSLTLSTicketKeysRing d_ticketKeys; + std::map d_ocspResponses; + std::unique_ptr d_tlsCtx{nullptr, SSL_CTX_free}; +}; + class OpenSSLTLSConnection: public TLSConnection { public: - OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout) + OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr feContext): d_feContext(feContext), d_conn(std::unique_ptr(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) { d_socket = socket; @@ -247,6 +273,7 @@ public: private: static std::atomic_flag s_initTLSConnIndex; + std::shared_ptr d_feContext; std::unique_ptr d_conn; unsigned int d_timeout; }; @@ -257,29 +284,24 @@ int OpenSSLTLSConnection::s_tlsConnIndex = -1; class OpenSSLTLSIOCtx: public TLSCtx { public: - OpenSSLTLSIOCtx(TLSFrontend& fe): d_ticketKeys(fe.d_tlsConfig.d_numberOfTicketsKeys) + OpenSSLTLSIOCtx(TLSFrontend& fe) { - registerOpenSSLUser(); - d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; + d_feContext = std::make_shared(fe.d_addr, fe.d_tlsConfig); - d_tlsCtx = libssl_init_server_context(fe.d_tlsConfig, d_ocspResponses); - if (!d_tlsCtx) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Error creating TLS context on " + fe.d_addr.toStringWithPort()); - } + d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) { /* use our own ticket keys handler so we can rotate them */ - SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); - libssl_set_ticket_key_callback_data(d_tlsCtx.get(), this); + SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); + libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get()); } - if (!d_ocspResponses.empty()) { - SSL_CTX_set_tlsext_status_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb); - SSL_CTX_set_tlsext_status_arg(d_tlsCtx.get(), &d_ocspResponses); + if (!d_feContext->d_ocspResponses.empty()) { + SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb); + SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses); } - libssl_set_error_counters_callback(d_tlsCtx, &fe.d_tlsCounters); + libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters); try { if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { @@ -294,16 +316,13 @@ public: } } - virtual ~OpenSSLTLSIOCtx() override + ~OpenSSLTLSIOCtx() override { - d_tlsCtx.reset(); - - unregisterOpenSSLUser(); } static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) { - OpenSSLTLSIOCtx* ctx = reinterpret_cast(libssl_get_ticket_key_callback_data(s)); + OpenSSLFrontendContext* ctx = reinterpret_cast(libssl_get_ticket_key_callback_data(s)); if (ctx == nullptr) { return -1; } @@ -339,12 +358,12 @@ public: { handleTicketsKeyRotation(now); - return std::unique_ptr(new OpenSSLTLSConnection(socket, timeout, d_tlsCtx.get())); + return std::unique_ptr(new OpenSSLTLSConnection(socket, timeout, d_feContext)); } void rotateTicketsKey(time_t now) override { - d_ticketKeys.rotateTicketsKey(now); + d_feContext->d_ticketKeys.rotateTicketsKey(now); if (d_ticketsKeyRotationDelay > 0) { d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; @@ -353,7 +372,7 @@ public: void loadTicketsKeys(const std::string& keyFile) override final { - d_ticketKeys.loadTicketsKeys(keyFile); + d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); if (d_ticketsKeyRotationDelay > 0) { d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; @@ -362,13 +381,11 @@ public: size_t getTicketsKeysCount() override { - return d_ticketKeys.getKeysCount(); + return d_feContext->d_ticketKeys.getKeysCount(); } private: - OpenSSLTLSTicketKeysRing d_ticketKeys; - std::map d_ocspResponses; - std::unique_ptr d_tlsCtx{nullptr, SSL_CTX_free}; + std::shared_ptr d_feContext; }; #endif /* HAVE_LIBSSL */