From: Remi Gacogne Date: Fri, 13 Sep 2024 13:57:38 +0000 (+0200) Subject: dnsdist: Always store the OpenSSLTLSIOCtx in the connection X-Git-Tag: rec-5.2.0-alpha1~78^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6aac1f054fe6253fe3b051d5c7e3f78a71547e4e;p=thirdparty%2Fpdns.git dnsdist: Always store the OpenSSLTLSIOCtx in the connection --- diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 81107cbc06..893c5fa174 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -86,11 +86,13 @@ private: std::unique_ptr d_sess; }; +class OpenSSLTLSIOCtx; + class OpenSSLTLSConnection: public TLSConnection { public: /* server side connection */ - OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) + OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr tlsCtx, std::unique_ptr&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout) { d_socket = socket; @@ -110,7 +112,7 @@ public: } /* client-side connection */ - OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout) + OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr tlsCtx, std::unique_ptr&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true) { d_socket = socket; @@ -297,7 +299,7 @@ public: IOState tryHandshake() override { - if (!d_feContext) { + if (isClient()) { /* In client mode, the handshake is initiated by the call to SSL_connect() done from connect()/tryConnect(). In blocking mode it does not return before the handshake has been finished, @@ -325,7 +327,7 @@ public: void doHandshake() override { - if (!d_feContext) { + if (isClient()) { /* we are a client, nothing to do, see the non-blocking version */ return; } @@ -346,7 +348,7 @@ public: IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override { - if (!d_feContext && !d_connected) { + if (isClient() && !d_connected) { if (d_ktls) { /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */ SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get())); @@ -565,6 +567,11 @@ public: d_ktls = true; } + bool isClient() const + { + return d_isClient; + } + static void generateConnectionIndexIfNeeded() { auto init = s_initTLSConnIndex.lock(); @@ -590,25 +597,38 @@ private: static LockGuarded s_initTLSConnIndex; static int s_tlsConnIndex; std::vector> d_tlsSessions; - /* server context */ - std::shared_ptr d_feContext; - /* client context */ - std::shared_ptr d_tlsCtx; + std::shared_ptr d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime std::unique_ptr d_conn; std::string d_hostname; struct timeval d_timeout; bool d_connected{false}; bool d_ktls{false}; + bool d_isClient{false}; }; LockGuarded OpenSSLTLSConnection::s_initTLSConnIndex{false}; int OpenSSLTLSConnection::s_tlsConnIndex{-1}; -class OpenSSLTLSIOCtx: public TLSCtx +class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this { + struct Private + { + explicit Private() = default; + }; + public: + static std::shared_ptr createServerSideContext(TLSFrontend& fe) + { + return std::make_shared(fe, Private()); + } + + static std::shared_ptr createClientSideContext(const TLSContextParameters& params) + { + return std::make_shared(params, Private()); + } + /* server side context */ - OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared(fe.d_addr, fe.d_tlsConfig)) + OpenSSLTLSIOCtx(TLSFrontend& fe, [[maybe_unused]] Private priv): d_feContext(std::make_unique(fe.d_addr, fe.d_tlsConfig)) { OpenSSLTLSConnection::generateConnectionIndexIfNeeded(); @@ -655,7 +675,7 @@ public: } /* client side context */ - OpenSSLTLSIOCtx(const TLSContextParameters& params) + OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private) { int sslOptions = SSL_OP_NO_SSLv2 | @@ -803,16 +823,24 @@ public: return 1; } + SSL_CTX* getOpenSSLContext() const + { + if (d_feContext) { + return d_feContext->d_tlsCtx.get(); + } + return d_tlsCtx.get(); + } + std::unique_ptr getConnection(int socket, const struct timeval& timeout, time_t now) override { handleTicketsKeyRotation(now); - return std::make_unique(socket, timeout, d_feContext); + return std::make_unique(socket, timeout, shared_from_this(), std::unique_ptr(SSL_new(getOpenSSLContext()), SSL_free)); } std::unique_ptr getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override { - auto conn = std::make_unique(host, hostIsAddr, socket, timeout, d_tlsCtx); + auto conn = std::make_unique(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr(SSL_new(getOpenSSLContext()), SSL_free)); if (d_ktls) { conn->enableKTLS(); } @@ -847,24 +875,32 @@ public: return "openssl"; } + bool isServerContext() const + { + return d_feContext != nullptr; + } + bool setALPNProtos(const std::vector>& protos) override { - if (d_feContext && d_feContext->d_tlsCtx) { + auto* openSSLContext = getOpenSSLContext(); + if (openSSLContext == nullptr) { + return false; + } + + if (isServerContext()) { d_alpnProtos = protos; - libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this); + libssl_set_alpn_select_callback(openSSLContext, alpnServerSelectCallback, this); return true; } - if (d_tlsCtx) { - return libssl_set_alpn_protos(d_tlsCtx.get(), protos); - } - return false; + + return libssl_set_alpn_protos(openSSLContext, protos); } #ifndef DISABLE_NPN bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override { d_nextProtocolSelectCallback = cb; - libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this); + libssl_set_npn_select_callback(getOpenSSLContext(), npnSelectCallback, this); return true; } #endif /* DISABLE_NPN */ @@ -919,8 +955,8 @@ private: } std::vector> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent - std::shared_ptr d_feContext{nullptr}; std::shared_ptr d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx + std::unique_ptr d_feContext{nullptr}; bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr}; bool d_ktls{false}; }; @@ -1870,13 +1906,13 @@ bool TLSFrontend::setupTLS() #endif /* HAVE_GNUTLS */ #if defined(HAVE_LIBSSL) if (d_provider == "openssl") { - newCtx = std::make_shared(*this); + newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this); } #endif /* HAVE_LIBSSL */ if (!newCtx) { #if defined(HAVE_LIBSSL) - newCtx = std::make_shared(*this); + newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this); #elif defined(HAVE_GNUTLS) newCtx = std::make_shared(*this); #else @@ -1908,13 +1944,13 @@ std::shared_ptr getTLSContext([[maybe_unused]] const TLSContextParameter #endif /* HAVE_GNUTLS */ #if defined(HAVE_LIBSSL) if (params.d_provider == "openssl") { - return std::make_shared(params); + return OpenSSLTLSIOCtx::createClientSideContext(params); } #endif /* HAVE_LIBSSL */ } #if defined(HAVE_LIBSSL) - return std::make_shared(params); + return OpenSSLTLSIOCtx::createClientSideContext(params); #elif defined(HAVE_GNUTLS) return std::make_shared(params); #else