]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Use callbacks to retrieve TLS tickets sent by the server
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 11 Jun 2021 10:25:34 +0000 (12:25 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 26 Aug 2021 14:30:28 +0000 (16:30 +0200)
In TLS 1.3, tickets can be sent at any moment of the TLS session,
and more importantly are not guaranteed to be sent before a few
bytes have been exchanged. In addition, GnuTLS invalidates a session
if the remote closes the connection in a unexpected way (which Python
seems to do, for example) so we can't rely on the ticket being available
at the end of the exchange either.
We now instead use callbacks so we can be notified as soon as a new
ticket arrives, and deal with it. We store inside the TLS connection
object so we can retrieve it at the end of the exchange, when
deciding whether the whole TCP connection can be reused or if we want
to tear it down and store the ticket for later resumption instead.

pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index 41a5b4731875f8374169e358576630741007dabb..99945ffaf3f199c9bf585bd8c69896d12de664d6 100644 (file)
@@ -16,9 +16,14 @@ TCPConnectionToBackend::~TCPConnectionToBackend()
       if (d_handler->hasTLSSessionBeenResumed()) {
         ++d_ds->tlsResumptions;
       }
-      auto session = d_handler->getTLSSession();
-      if (session) {
-        g_sessionCache.putSession(d_ds->getID(), now.tv_sec, std::move(session));
+      try {
+        auto session = d_handler->getTLSSession();
+        if (session) {
+          g_sessionCache.putSession(d_ds->getID(), now.tv_sec, std::move(session));
+        }
+      }
+      catch (const std::exception& e) {
+        vinfolog("Unable to get a TLS session: %s", e.what());
       }
     }
     auto diff = now - d_connectionStartTime;
@@ -161,6 +166,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
          Let's just drop the connection
       */
       vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_state == State::sendingQueryToBackend ? "writing to" : "reading from"), conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+
       if (conn->d_state == State::sendingQueryToBackend) {
         ++conn->d_ds->tcpDiedSendingQuery;
       }
@@ -333,7 +339,12 @@ bool TCPConnectionToBackend::reconnect()
       if (d_handler->hasTLSSessionBeenResumed()) {
         ++d_ds->tlsResumptions;
       }
-      tlsSession = d_handler->getTLSSession();
+      try {
+        tlsSession = d_handler->getTLSSession();
+      }
+      catch (const std::exception& e) {
+        vinfolog("Unable to get a TLS session to resume: %s", e.what());
+      }
     }
     d_handler->close();
     d_ioState.reset();
index 5c2321c4505b51c2cb448c30f0d53061b4ad89ce..17ab57ad7f6b98654675f8204e6b1370483536c2 100644 (file)
@@ -237,7 +237,7 @@ public:
     return false;
   }
 
-  std::unique_ptr<TLSSession> getSession() const override
+  std::unique_ptr<TLSSession> getSession() override
   {
     return nullptr;
   }
index ae8fbbbb5092e0d808f5365133a09be3e5d063d4..7efb0a410ba3d160bf544f95ecab1565ac0db924 100644 (file)
@@ -103,6 +103,14 @@ public:
   {
     d_socket = socket;
 
+    if (!s_initTLSConnIndex.test_and_set()) {
+      /* not initialized yet */
+      s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+      if (s_tlsConnIndex == -1) {
+        throw std::runtime_error("Error getting an index for TLS connection data");
+      }
+    }
+
     if (!d_conn) {
       vinfolog("Error creating TLS object");
       if (g_verbose) {
@@ -130,6 +138,7 @@ public:
 #else
     /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
 #endif
+    SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
   }
 
   IOState convertIORequestToIOState(int res) const
@@ -403,9 +412,13 @@ public:
     return false;
   }
 
-  std::unique_ptr<TLSSession> getSession() const override
+  std::unique_ptr<TLSSession> getSession() override
   {
-    return std::unique_ptr<TLSSession>(new OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>(SSL_get1_session(d_conn.get()), SSL_SESSION_free)));
+    if (d_tlsSession) {
+      return std::move(d_tlsSession);
+    }
+
+    throw std::runtime_error("Unable to get an OpenSSL session");
   }
 
   void setSession(std::unique_ptr<TLSSession>& session) override
@@ -423,6 +436,11 @@ public:
     native.release();
   }
 
+  void setNewTicket(SSL_SESSION* session)
+  {
+    d_tlsSession = std::unique_ptr<TLSSession>(new OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>(session, SSL_SESSION_free)));
+  }
+
   static int s_tlsConnIndex;
 
 private:
@@ -430,6 +448,7 @@ private:
 
   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
+  std::unique_ptr<TLSSession> d_tlsSession{nullptr};
   std::string d_hostname;
   struct timeval d_timeout;
 };
@@ -535,6 +554,11 @@ public:
       warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
 #endif
     }
+
+    /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
+       but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
+    SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
+    SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);
   }
 
   ~OpenSSLTLSIOCtx() override
@@ -577,6 +601,17 @@ public:
     return libssl_ocsp_stapling_callback(ssl, *ocspMap);
   }
 
+  static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
+  {
+    OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(ssl, OpenSSLTLSConnection::s_tlsConnIndex));
+    if (session == nullptr || conn == nullptr) {
+      return 0;
+    }
+
+    conn->setNewTicket(session);
+    return 1;
+  }
+
   std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
   {
     handleTicketsKeyRotation(now);
@@ -746,7 +781,7 @@ class GnuTLSConnection: public TLSConnection
 {
 public:
   /* server side connection */
-  GnuTLSConnection(int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
+  GnuTLSConnection(int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
   {
     unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
@@ -830,6 +865,31 @@ public:
 #else
     /* no hostname validation for you */
 #endif
+
+    /* allow access to our data in the callbacks */
+    gnutls_session_set_ptr(d_conn.get(), this);
+    gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb);
+  }
+
+  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming, const gnutls_datum_t* msg)
+  {
+    if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
+      return 0;
+    }
+
+    GnuTLSConnection* conn = reinterpret_cast<GnuTLSConnection*>(gnutls_session_get_ptr(session));
+    if (conn == nullptr) {
+      return 0;
+    }
+
+    gnutls_datum_t sess{nullptr, 0};
+    auto ret = gnutls_session_get_data2(session, &sess);
+    /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
+    if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) {
+      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
+    }
+    conn->d_tlsSession = std::unique_ptr<TLSSession>(new GnuTLSSession(sess));
+    return 0;
   }
 
   IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
@@ -1157,28 +1217,13 @@ public:
     return false;
   }
 
-  std::unique_ptr<TLSSession> getSession() const override
+  std::unique_ptr<TLSSession> getSession() override
   {
-    if (getTLSVersion() == LibsslTLSVersion::TLS13) {
-#if GNUTLS_VERSION_NUMBER >= 0x030603
-      /* with TLS 1.3, gnutls_session_get_data2() will _wait_ for a ticket is there is none yet.. */
-      if ((gnutls_session_get_flags(d_conn.get()) & GNUTLS_SFLAGS_SESSION_TICKET) == 0) {
-        return nullptr;
-      }
-#else /* GNUTLS_VERSION_NUMBER >= 0x030603 */
-      /* the GNUTLS_SFLAGS_SESSION_TICKET flag does not exist before 3.6.3 (but TLS 1.3 should not either), so we can't be sure we are not
-         going to block, better give up. */
-      return nullptr;
-#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
-    }
-
-    gnutls_datum_t sess{nullptr, 0};
-    auto ret = gnutls_session_get_data2(d_conn.get(), &sess);
-    if (ret != GNUTLS_E_SUCCESS) {
-      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
+    if (d_tlsSession) {
+      return std::move(d_tlsSession);
     }
 
-    return std::unique_ptr<TLSSession>(new GnuTLSSession(sess));
+    throw std::runtime_error("No GnuTLSSession available yet");
   }
 
   void setSession(std::unique_ptr<TLSSession>& session) override
@@ -1193,6 +1238,7 @@ public:
     if (ret != GNUTLS_E_SUCCESS) {
       throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
     }
+
     session.release();
   }
 
@@ -1204,8 +1250,9 @@ public:
   }
 
 private:
-  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
+  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
+  std::unique_ptr<TLSSession> d_tlsSession{nullptr};
   std::string d_host;
   bool d_client{false};
   bool d_handshakeDone{false};
index 70260c6c722c10f11d86d056f671856d51ceedc3..8057eb70b22405826f4cdbad4e9f20e27bd9cebf 100644 (file)
@@ -35,7 +35,7 @@ public:
   virtual std::string getServerNameIndication() const = 0;
   virtual LibsslTLSVersion getTLSVersion() const = 0;
   virtual bool hasSessionBeenResumed() const = 0;
-  virtual std::unique_ptr<TLSSession> getSession() const = 0;
+  virtual std::unique_ptr<TLSSession> getSession() = 0;
   virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
   virtual void close() = 0;