]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
TCPIOHandler: Add preliminary support for session resumption
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 29 Apr 2021 15:57:59 +0000 (17:57 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 26 Aug 2021 14:30:27 +0000 (16:30 +0200)
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index 118b0636389037dde2c3aa03ec1abb10c7760121..ef13a82e4e66d7016e643afbc35aa82aacb8eaf0 100644 (file)
@@ -237,6 +237,16 @@ public:
     return false;
   }
 
+  std::unique_ptr<TLSSession> getSession() const override
+  {
+    throw std::runtime_error("getSession() not implemented");
+  }
+
+  void setSession(std::unique_ptr<TLSSession>& session) override
+  {
+    throw std::runtime_error("setSession() not implemented");
+  }
+
   /* unused in that context, don't bother */
   void doHandshake() override
   {
index 71405e5b5bbe9fb5fea2a669d892abf2f3f6c74f..1125d63ab63587f27b4e27c7a371aa52f7b207e5 100644 (file)
@@ -47,6 +47,26 @@ public:
   std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
 };
 
+class OpenSSLSession : public TLSSession
+{
+public:
+  OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
+  {
+  }
+
+  virtual ~OpenSSLSession()
+  {
+  }
+
+  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
+  {
+    return std::move(d_sess);
+  }
+
+private:
+  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
+};
+
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
@@ -383,6 +403,26 @@ public:
     return false;
   }
 
+  std::unique_ptr<TLSSession> getSession() const override
+  {
+    return std::unique_ptr<TLSSession>(new OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>(SSL_get1_session(d_conn.get()), SSL_SESSION_free)));
+  }
+
+  void setSession(std::unique_ptr<TLSSession>& session) override
+  {
+    auto sess = dynamic_cast<OpenSSLSession*>(session.get());
+    if (!sess) {
+      throw std::runtime_error("Unable to convert OpenSSL session");
+    }
+
+    auto native = sess->getNative();
+    auto ret = SSL_set_session(d_conn.get(), native.get());
+    if (ret != 1) {
+      throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
+    }
+    native.release();
+  }
+
   static int s_tlsConnIndex;
 
 private:
@@ -672,6 +712,36 @@ private:
   gnutls_datum_t d_key{nullptr, 0};
 };
 
+class GnuTLSSession : public TLSSession
+{
+public:
+  GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
+  {
+    sess.data = nullptr;
+    sess.size = 0;
+  }
+
+  virtual ~GnuTLSSession()
+  {
+    if (d_sess.data != nullptr && d_sess.size > 0) {
+      safe_memory_release(d_sess.data, d_sess.size);
+    }
+    gnutls_free(d_sess.data);
+    d_sess.data = nullptr;
+  }
+
+  gnutls_datum_t getNative()
+  {
+    auto ret = d_sess;
+    d_sess.data = nullptr;
+    d_sess.size = 0;
+    return ret;
+  }
+
+private:
+  gnutls_datum_t d_sess{nullptr, 0};
+};
+
 class GnuTLSConnection: public TLSConnection
 {
 public:
@@ -1062,6 +1132,32 @@ public:
     return false;
   }
 
+  std::unique_ptr<TLSSession> getSession() const override
+  {
+    gnutls_datum_t sess{nullptr, 0};
+    auto ret = gnutls_session_get_data2(d_conn.get(), &sess);
+    if (ret != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
+    }
+
+    return std::unique_ptr<TLSSession>(new GnuTLSSession(sess));
+  }
+
+  void setSession(std::unique_ptr<TLSSession>& session) override
+  {
+    auto sess = dynamic_cast<GnuTLSSession*>(session.get());
+    if (!sess) {
+      throw std::runtime_error("Unable to convert GnuTLS session");
+    }
+
+    auto native = sess->getNative();
+    auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
+    if (ret != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
+    }
+    session.release();
+  }
+
   void close() override
   {
     if (d_conn) {
index c73d4d0999a80c0f9b4754f467e09cd6e4e66d43..7c78ef65f9cf590e84c087e1c106f160ef9e797c 100644 (file)
 
 enum class IOState : uint8_t { Done, NeedRead, NeedWrite };
 
+class TLSSession
+{
+public:
+  virtual ~TLSSession()
+  {
+  }
+};
+
 class TLSConnection
 {
 public:
@@ -26,6 +34,8 @@ public:
   virtual std::string getServerNameIndication() const = 0;
   virtual LibsslTLSVersion getTLSVersion() const = 0;
   virtual bool hasSessionBeenResumed() const = 0;
+  virtual std::unique_ptr<TLSSession> getSession() const = 0;
+  virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
   virtual void close() = 0;
 
   void setUnknownTicketKey()
@@ -465,6 +475,22 @@ public:
     return d_conn && d_conn->getUnknownTicketKey();
   }
 
+  void setTLSSession(std::unique_ptr<TLSSession>& session)
+  {
+    if (d_conn != nullptr) {
+      d_conn->setSession(session);
+    }
+  }
+
+  std::unique_ptr<TLSSession> getTLSSession()
+  {
+    if (!d_conn) {
+      throw std::runtime_error("Trying to get a TLS session from a non-TLS handler");
+    }
+
+    return d_conn->getSession();
+  }
+
 private:
   std::unique_ptr<TLSConnection> d_conn{nullptr};
   ComboAddress d_remote;
@@ -484,4 +510,3 @@ struct TLSContextParameters
 };
 
 std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
-