]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix the client TLS wrapper for GnuTLS
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 29 Apr 2021 13:58:06 +0000 (15:58 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 26 Aug 2021 14:30:27 +0000 (16:30 +0200)
We need to call gnutls_handshake() repeatedly until it succeeds, while
OpenSSL allows us to read and write transparently.

pdns/tcpiohandler.cc

index daa49a1a59593fa9ccc312e4d92513fc3d0937b0..71405e5b5bbe9fb5fea2a669d892abf2f3f6c74f 100644 (file)
@@ -211,6 +211,11 @@ public:
 
   IOState tryHandshake() override
   {
+    if (!d_feContext) {
+      /* we are a client, nothing to do */
+      return IOState::Done;
+    }
+
     int res = SSL_accept(d_conn.get());
     if (res == 1) {
       return IOState::Done;
@@ -224,6 +229,11 @@ public:
 
   void doHandshake() override
   {
+    if (!d_feContext) {
+      /* we are a client, nothing to do */
+      return;
+    }
+
     int res = 0;
     do {
       res = SSL_accept(d_conn.get());
@@ -706,7 +716,7 @@ public:
   }
 
   /* client-side connection */
-  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host)
+  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
   {
     unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
@@ -765,6 +775,7 @@ public:
     do {
       ret = gnutls_handshake(d_conn.get());
       if (ret == GNUTLS_E_SUCCESS) {
+        d_handshakeDone = true;
         return IOState::Done;
       }
       else if (ret == GNUTLS_E_AGAIN) {
@@ -829,7 +840,9 @@ public:
         throw std::runtime_error("Error accepting a new connection");
       }
     }
-    while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
+    while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
+
+    d_handshakeDone = true;
   }
 
   IOState tryHandshake() override
@@ -839,10 +852,12 @@ public:
     do {
       ret = gnutls_handshake(d_conn.get());
       if (ret == GNUTLS_E_SUCCESS) {
+        d_handshakeDone = true;
         return IOState::Done;
       }
       else if (ret == GNUTLS_E_AGAIN) {
-        return IOState::NeedRead;
+        int direction = gnutls_record_get_direction(d_conn.get());
+        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
       }
       else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
         throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
@@ -854,6 +869,13 @@ public:
 
   IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
   {
+    if (!d_handshakeDone) {
+      auto state = tryHandshake();
+      if (state != IOState::Done) {
+        return state;
+      }
+    }
+
     do {
       ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
       if (res == 0) {
@@ -878,6 +900,13 @@ public:
 
   IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
   {
+    if (!d_handshakeDone) {
+      auto state = tryHandshake();
+      if (state != IOState::Done) {
+        return state;
+      }
+    }
+
     do {
       ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
       if (res == 0) {
@@ -1044,6 +1073,8 @@ private:
   std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
   std::string d_host;
+  bool d_client{false};
+  bool d_handshakeDone{false};
 };
 
 class GnuTLSIOCtx: public TLSCtx