]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Always store the OpenSSLTLSIOCtx in the connection
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 13 Sep 2024 13:57:38 +0000 (15:57 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 16 Sep 2024 10:30:17 +0000 (12:30 +0200)
pdns/tcpiohandler.cc

index 81107cbc068e3a03607228ce25bd506290b5d47d..893c5fa174a88079d4cd9728114422d451f94d97 100644 (file)
@@ -86,11 +86,13 @@ private:
   std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
 };
 
+class OpenSSLTLSIOCtx;
+
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
   /* server side connection */
-  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
+  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout)
   {
     d_socket = socket;
 
@@ -110,7 +112,7 @@ public:
   }
 
   /* client-side connection */
-  OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
+  OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true)
   {
     d_socket = socket;
 
@@ -297,7 +299,7 @@ public:
 
   IOState tryHandshake() override
   {
-    if (!d_feContext) {
+    if (isClient()) {
       /* In client mode, the handshake is initiated by the call to SSL_connect()
          done from connect()/tryConnect().
          In blocking mode it does not return before the handshake has been finished,
@@ -325,7 +327,7 @@ public:
 
   void doHandshake() override
   {
-    if (!d_feContext) {
+    if (isClient()) {
       /* we are a client, nothing to do, see the non-blocking version */
       return;
     }
@@ -346,7 +348,7 @@ public:
 
   IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
   {
-    if (!d_feContext && !d_connected) {
+    if (isClient() && !d_connected) {
       if (d_ktls) {
         /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
         SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
@@ -565,6 +567,11 @@ public:
     d_ktls = true;
   }
 
+  bool isClient() const
+  {
+    return d_isClient;
+  }
+
   static void generateConnectionIndexIfNeeded()
   {
     auto init = s_initTLSConnIndex.lock();
@@ -590,25 +597,38 @@ private:
   static LockGuarded<bool> s_initTLSConnIndex;
   static int s_tlsConnIndex;
   std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
-  /* server context */
-  std::shared_ptr<OpenSSLFrontendContext> d_feContext;
-  /* client context */
-  std::shared_ptr<SSL_CTX> d_tlsCtx;
+  std::shared_ptr<const OpenSSLTLSIOCtx> d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
   std::string d_hostname;
   struct timeval d_timeout;
   bool d_connected{false};
   bool d_ktls{false};
+  bool d_isClient{false};
 };
 
 LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
 int OpenSSLTLSConnection::s_tlsConnIndex{-1};
 
-class OpenSSLTLSIOCtx: public TLSCtx
+class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this<OpenSSLTLSIOCtx>
 {
+  struct Private
+  {
+    explicit Private() = default;
+  };
+
 public:
+  static std::shared_ptr<OpenSSLTLSIOCtx> createServerSideContext(TLSFrontend& fe)
+  {
+    return std::make_shared<OpenSSLTLSIOCtx>(fe, Private());
+  }
+
+  static std::shared_ptr<OpenSSLTLSIOCtx> createClientSideContext(const TLSContextParameters& params)
+  {
+    return std::make_shared<OpenSSLTLSIOCtx>(params, Private());
+  }
+
   /* server side context */
-  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
+  OpenSSLTLSIOCtx(TLSFrontend& fe, [[maybe_unused]] Private priv): d_feContext(std::make_unique<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
   {
     OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
 
@@ -655,7 +675,7 @@ public:
   }
 
   /* client side context */
-  OpenSSLTLSIOCtx(const TLSContextParameters& params)
+  OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private)
   {
     int sslOptions =
       SSL_OP_NO_SSLv2 |
@@ -803,16 +823,24 @@ public:
     return 1;
   }
 
+  SSL_CTX* getOpenSSLContext() const
+  {
+    if (d_feContext) {
+      return d_feContext->d_tlsCtx.get();
+    }
+    return d_tlsCtx.get();
+  }
+
   std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
   {
     handleTicketsKeyRotation(now);
 
-    return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
+    return std::make_unique<OpenSSLTLSConnection>(socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
   }
 
   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
   {
-    auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, d_tlsCtx);
+    auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
     if (d_ktls) {
       conn->enableKTLS();
     }
@@ -847,24 +875,32 @@ public:
     return "openssl";
   }
 
+  bool isServerContext() const
+  {
+    return d_feContext != nullptr;
+  }
+
   bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
   {
-    if (d_feContext && d_feContext->d_tlsCtx) {
+    auto* openSSLContext = getOpenSSLContext();
+    if (openSSLContext == nullptr) {
+      return false;
+    }
+
+    if (isServerContext()) {
       d_alpnProtos = protos;
-      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
+      libssl_set_alpn_select_callback(openSSLContext, alpnServerSelectCallback, this);
       return true;
     }
-    if (d_tlsCtx) {
-      return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
-    }
-    return false;
+
+    return libssl_set_alpn_protos(openSSLContext, protos);
   }
 
 #ifndef DISABLE_NPN
   bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
   {
     d_nextProtocolSelectCallback = cb;
-    libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
+    libssl_set_npn_select_callback(getOpenSSLContext(), npnSelectCallback, this);
     return true;
   }
 #endif /* DISABLE_NPN */
@@ -919,8 +955,8 @@ private:
   }
 
   std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
-  std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
   std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
+  std::unique_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
   bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
   bool d_ktls{false};
 };
@@ -1870,13 +1906,13 @@ bool TLSFrontend::setupTLS()
 #endif /* HAVE_GNUTLS */
 #if defined(HAVE_LIBSSL)
   if (d_provider == "openssl") {
-    newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+    newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
   }
 #endif /* HAVE_LIBSSL */
 
   if (!newCtx) {
 #if defined(HAVE_LIBSSL)
-    newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+    newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
 #elif defined(HAVE_GNUTLS)
     newCtx = std::make_shared<GnuTLSIOCtx>(*this);
 #else
@@ -1908,13 +1944,13 @@ std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameter
 #endif /* HAVE_GNUTLS */
 #if defined(HAVE_LIBSSL)
     if (params.d_provider == "openssl") {
-      return std::make_shared<OpenSSLTLSIOCtx>(params);
+      return OpenSSLTLSIOCtx::createClientSideContext(params);
     }
 #endif /* HAVE_LIBSSL */
   }
 
 #if defined(HAVE_LIBSSL)
-  return std::make_shared<OpenSSLTLSIOCtx>(params);
+  return OpenSSLTLSIOCtx::createClientSideContext(params);
 #elif defined(HAVE_GNUTLS)
   return std::make_shared<GnuTLSIOCtx>(params);
 #else