From f14050fa295d1f8b656d8e1dfdab6b5c8414260b Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Thu, 29 Apr 2021 15:58:06 +0200 Subject: [PATCH] dnsdist: Fix the client TLS wrapper for GnuTLS We need to call gnutls_handshake() repeatedly until it succeeds, while OpenSSL allows us to read and write transparently. --- pdns/tcpiohandler.cc | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index daa49a1a59..71405e5b5b 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -211,6 +211,11 @@ public: IOState tryHandshake() override { + if (!d_feContext) { + /* we are a client, nothing to do */ + return IOState::Done; + } + int res = SSL_accept(d_conn.get()); if (res == 1) { return IOState::Done; @@ -224,6 +229,11 @@ public: void doHandshake() override { + if (!d_feContext) { + /* we are a client, nothing to do */ + return; + } + int res = 0; do { res = SSL_accept(d_conn.get()); @@ -706,7 +716,7 @@ public: } /* client-side connection */ - GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_host(host) + GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_host(host), d_client(true) { unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK; #ifdef GNUTLS_NO_SIGNAL @@ -765,6 +775,7 @@ public: do { ret = gnutls_handshake(d_conn.get()); if (ret == GNUTLS_E_SUCCESS) { + d_handshakeDone = true; return IOState::Done; } else if (ret == GNUTLS_E_AGAIN) { @@ -829,7 +840,9 @@ public: throw std::runtime_error("Error accepting a new connection"); } } - while (ret < 0 && ret == GNUTLS_E_INTERRUPTED); + while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED); + + d_handshakeDone = true; } IOState tryHandshake() override @@ -839,10 +852,12 @@ public: do { ret = gnutls_handshake(d_conn.get()); if (ret == GNUTLS_E_SUCCESS) { + d_handshakeDone = true; return IOState::Done; } else if (ret == GNUTLS_E_AGAIN) { - return IOState::NeedRead; + int direction = gnutls_record_get_direction(d_conn.get()); + return direction == 0 ? IOState::NeedRead : IOState::NeedWrite; } else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret))); @@ -854,6 +869,13 @@ public: IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override { + if (!d_handshakeDone) { + auto state = tryHandshake(); + if (state != IOState::Done) { + return state; + } + } + do { ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toWrite - pos); if (res == 0) { @@ -878,6 +900,13 @@ public: IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override { + if (!d_handshakeDone) { + auto state = tryHandshake(); + if (state != IOState::Done) { + return state; + } + } + do { ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toRead - pos); if (res == 0) { @@ -1044,6 +1073,8 @@ private: std::unique_ptr d_conn; std::shared_ptr d_ticketsKey; std::string d_host; + bool d_client{false}; + bool d_handshakeDone{false}; }; class GnuTLSIOCtx: public TLSCtx -- 2.47.2