std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
};
+class OpenSSLSession : public TLSSession
+{
+public:
+ OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
+ {
+ }
+
+ virtual ~OpenSSLSession()
+ {
+ }
+
+ std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
+ {
+ return std::move(d_sess);
+ }
+
+private:
+ std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
+};
+
class OpenSSLTLSConnection: public TLSConnection
{
public:
return false;
}
+ std::unique_ptr<TLSSession> getSession() const override
+ {
+ return std::unique_ptr<TLSSession>(new OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>(SSL_get1_session(d_conn.get()), SSL_SESSION_free)));
+ }
+
+ void setSession(std::unique_ptr<TLSSession>& session) override
+ {
+ auto sess = dynamic_cast<OpenSSLSession*>(session.get());
+ if (!sess) {
+ throw std::runtime_error("Unable to convert OpenSSL session");
+ }
+
+ auto native = sess->getNative();
+ auto ret = SSL_set_session(d_conn.get(), native.get());
+ if (ret != 1) {
+ throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
+ }
+ native.release();
+ }
+
static int s_tlsConnIndex;
private:
gnutls_datum_t d_key{nullptr, 0};
};
+class GnuTLSSession : public TLSSession
+{
+public:
+ GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
+ {
+ sess.data = nullptr;
+ sess.size = 0;
+ }
+
+ virtual ~GnuTLSSession()
+ {
+ if (d_sess.data != nullptr && d_sess.size > 0) {
+ safe_memory_release(d_sess.data, d_sess.size);
+ }
+ gnutls_free(d_sess.data);
+ d_sess.data = nullptr;
+ }
+
+ gnutls_datum_t getNative()
+ {
+ auto ret = d_sess;
+ d_sess.data = nullptr;
+ d_sess.size = 0;
+ return ret;
+ }
+
+private:
+ gnutls_datum_t d_sess{nullptr, 0};
+};
+
class GnuTLSConnection: public TLSConnection
{
public:
return false;
}
+ std::unique_ptr<TLSSession> getSession() const override
+ {
+ gnutls_datum_t sess{nullptr, 0};
+ auto ret = gnutls_session_get_data2(d_conn.get(), &sess);
+ if (ret != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
+ }
+
+ return std::unique_ptr<TLSSession>(new GnuTLSSession(sess));
+ }
+
+ void setSession(std::unique_ptr<TLSSession>& session) override
+ {
+ auto sess = dynamic_cast<GnuTLSSession*>(session.get());
+ if (!sess) {
+ throw std::runtime_error("Unable to convert GnuTLS session");
+ }
+
+ auto native = sess->getNative();
+ auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
+ if (ret != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
+ }
+ session.release();
+ }
+
void close() override
{
if (d_conn) {
enum class IOState : uint8_t { Done, NeedRead, NeedWrite };
+class TLSSession
+{
+public:
+ virtual ~TLSSession()
+ {
+ }
+};
+
class TLSConnection
{
public:
virtual std::string getServerNameIndication() const = 0;
virtual LibsslTLSVersion getTLSVersion() const = 0;
virtual bool hasSessionBeenResumed() const = 0;
+ virtual std::unique_ptr<TLSSession> getSession() const = 0;
+ virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
virtual void close() = 0;
void setUnknownTicketKey()
return d_conn && d_conn->getUnknownTicketKey();
}
+ void setTLSSession(std::unique_ptr<TLSSession>& session)
+ {
+ if (d_conn != nullptr) {
+ d_conn->setSession(session);
+ }
+ }
+
+ std::unique_ptr<TLSSession> getTLSSession()
+ {
+ if (!d_conn) {
+ throw std::runtime_error("Trying to get a TLS session from a non-TLS handler");
+ }
+
+ return d_conn->getSession();
+ }
+
private:
std::unique_ptr<TLSConnection> d_conn{nullptr};
ComboAddress d_remote;
};
std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
-