From: Remi Gacogne Date: Thu, 29 Apr 2021 15:57:59 +0000 (+0200) Subject: TCPIOHandler: Add preliminary support for session resumption X-Git-Tag: dnsdist-1.7.0-alpha1~45^2~30 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9a311786b43d7fdf01bfe9c34942530bc801fccd;p=thirdparty%2Fpdns.git TCPIOHandler: Add preliminary support for session resumption --- diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 118b063638..ef13a82e4e 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -237,6 +237,16 @@ public: return false; } + std::unique_ptr getSession() const override + { + throw std::runtime_error("getSession() not implemented"); + } + + void setSession(std::unique_ptr& session) override + { + throw std::runtime_error("setSession() not implemented"); + } + /* unused in that context, don't bother */ void doHandshake() override { diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 71405e5b5b..1125d63ab6 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -47,6 +47,26 @@ public: std::unique_ptr d_keyLogFile{nullptr, fclose}; }; +class OpenSSLSession : public TLSSession +{ +public: + OpenSSLSession(std::unique_ptr&& sess): d_sess(std::move(sess)) + { + } + + virtual ~OpenSSLSession() + { + } + + std::unique_ptr getNative() + { + return std::move(d_sess); + } + +private: + std::unique_ptr d_sess; +}; + class OpenSSLTLSConnection: public TLSConnection { public: @@ -383,6 +403,26 @@ public: return false; } + std::unique_ptr getSession() const override + { + return std::unique_ptr(new OpenSSLSession(std::unique_ptr(SSL_get1_session(d_conn.get()), SSL_SESSION_free))); + } + + void setSession(std::unique_ptr& session) override + { + auto sess = dynamic_cast(session.get()); + if (!sess) { + throw std::runtime_error("Unable to convert OpenSSL session"); + } + + auto native = sess->getNative(); + auto ret = SSL_set_session(d_conn.get(), native.get()); + if (ret != 1) { + throw std::runtime_error("Error setting up session: " + libssl_get_error_string()); + } + native.release(); + } + static int s_tlsConnIndex; private: @@ -672,6 +712,36 @@ private: gnutls_datum_t d_key{nullptr, 0}; }; +class GnuTLSSession : public TLSSession +{ +public: + GnuTLSSession(gnutls_datum_t& sess): d_sess(sess) + { + sess.data = nullptr; + sess.size = 0; + } + + virtual ~GnuTLSSession() + { + if (d_sess.data != nullptr && d_sess.size > 0) { + safe_memory_release(d_sess.data, d_sess.size); + } + gnutls_free(d_sess.data); + d_sess.data = nullptr; + } + + gnutls_datum_t getNative() + { + auto ret = d_sess; + d_sess.data = nullptr; + d_sess.size = 0; + return ret; + } + +private: + gnutls_datum_t d_sess{nullptr, 0}; +}; + class GnuTLSConnection: public TLSConnection { public: @@ -1062,6 +1132,32 @@ public: return false; } + std::unique_ptr getSession() const override + { + gnutls_datum_t sess{nullptr, 0}; + auto ret = gnutls_session_get_data2(d_conn.get(), &sess); + if (ret != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret))); + } + + return std::unique_ptr(new GnuTLSSession(sess)); + } + + void setSession(std::unique_ptr& session) override + { + auto sess = dynamic_cast(session.get()); + if (!sess) { + throw std::runtime_error("Unable to convert GnuTLS session"); + } + + auto native = sess->getNative(); + auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size); + if (ret != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret))); + } + session.release(); + } + void close() override { if (d_conn) { diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index c73d4d0999..7c78ef65f9 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -10,6 +10,14 @@ enum class IOState : uint8_t { Done, NeedRead, NeedWrite }; +class TLSSession +{ +public: + virtual ~TLSSession() + { + } +}; + class TLSConnection { public: @@ -26,6 +34,8 @@ public: virtual std::string getServerNameIndication() const = 0; virtual LibsslTLSVersion getTLSVersion() const = 0; virtual bool hasSessionBeenResumed() const = 0; + virtual std::unique_ptr getSession() const = 0; + virtual void setSession(std::unique_ptr& session) = 0; virtual void close() = 0; void setUnknownTicketKey() @@ -465,6 +475,22 @@ public: return d_conn && d_conn->getUnknownTicketKey(); } + void setTLSSession(std::unique_ptr& session) + { + if (d_conn != nullptr) { + d_conn->setSession(session); + } + } + + std::unique_ptr getTLSSession() + { + if (!d_conn) { + throw std::runtime_error("Trying to get a TLS session from a non-TLS handler"); + } + + return d_conn->getSession(); + } + private: std::unique_ptr d_conn{nullptr}; ComboAddress d_remote; @@ -484,4 +510,3 @@ struct TLSContextParameters }; std::shared_ptr getTLSContext(const TLSContextParameters& params); -