]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
tcpiohandler: Add support for partial reads
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 12 Jul 2021 13:49:41 +0000 (15:49 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 13 Sep 2021 13:28:27 +0000 (15:28 +0200)
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/misc.cc
pdns/misc.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index d5483edd3f319eb71533f125bd307dd49a1ec6f3..c3411a55b51c9f84af2d8e0aa2a436705fbb0839 100644 (file)
@@ -174,7 +174,7 @@ public:
     return step.nextState;
   }
 
-  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) override
   {
     auto step = getStep();
     BOOST_REQUIRE_EQUAL(step.request, !d_client ? ExpectedStep::ExpectedRequest::readFromClient : ExpectedStep::ExpectedRequest::readFromBackend);
@@ -255,7 +255,7 @@ public:
   {
   }
 
-  size_t read(void* buffer, size_t bufferSize, const struct timeval&readTimeout, const struct timeval& totalTimeout={0,0}) override
+  size_t read(void* buffer, size_t bufferSize, const struct timeval&readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) override
   {
     return 0;
   }
index b39a36752e7445bec1e6f17208f6a2162aa66190..af667bf668b2439a4e9e5e0a138c5f72c9c2837f 100644 (file)
@@ -110,7 +110,7 @@ size_t readn2(int fd, void* buffer, size_t len)
   return len;
 }
 
-size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout)
+size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout, bool allowIncomplete)
 {
   size_t pos = 0;
   struct timeval start{0,0};
@@ -123,6 +123,9 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval&
     ssize_t got = read(fd, (char *)buffer + pos, len - pos);
     if (got > 0) {
       pos += (size_t) got;
+      if (allowIncomplete) {
+        break;
+      }
     }
     else if (got == 0) {
       throw runtime_error("EOF while reading message");
index 0372e8ffae41b38aed6789426bd15803818b9c6b..c4f54575773febdb724b7daf28a09b040ac0ab74 100644 (file)
@@ -149,7 +149,7 @@ vstringtok (Container &container, string const &in,
 size_t writen2(int fd, const void *buf, size_t count);
 inline size_t writen2(int fd, const std::string &s) { return writen2(fd, s.data(), s.size()); }
 size_t readn2(int fd, void* buffer, size_t len);
-size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout={0,0});
+size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false);
 size_t writen2WithTimeout(int fd, const void * buffer, size_t len, const struct timeval& timeout);
 
 void toLowerInPlace(string& str);
index b5360fe9aa37a6e2befebeee4ee22b225c735adf..4eb80627ab54fa3ea80f1790d6d97bf3549ec484 100644 (file)
@@ -302,7 +302,7 @@ public:
     return IOState::Done;
   }
 
-  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
   {
     do {
       int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
@@ -311,13 +311,16 @@ public:
       }
       else {
         pos += static_cast<size_t>(res);
+        if (allowIncomplete) {
+          break;
+        }
       }
     }
     while (pos < toRead);
     return IOState::Done;
   }
 
-  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout) override
+  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
   {
     size_t got = 0;
     struct timeval start = {0, 0};
@@ -333,6 +336,9 @@ public:
       }
       else {
         got += static_cast<size_t>(res);
+        if (allowIncomplete) {
+          break;
+        }
       }
 
       if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
@@ -1075,7 +1081,7 @@ public:
     return IOState::Done;
   }
 
-  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
   {
     if (!d_handshakeDone) {
       /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
@@ -1093,6 +1099,9 @@ public:
       }
       else if (res > 0) {
         pos += static_cast<size_t>(res);
+        if (allowIncomplete) {
+          break;
+        }
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
@@ -1108,7 +1117,7 @@ public:
     return IOState::Done;
   }
 
-  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout) override
+  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
   {
     size_t got = 0;
     struct timeval start{0,0};
@@ -1124,6 +1133,9 @@ public:
       }
       else if (res > 0) {
         got += static_cast<size_t>(res);
+        if (allowIncomplete) {
+          break;
+        }
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
index d18a92b05af22d30934d681b9c99d996f7094f94..e948b130aea5f2da27caea7c5ca5dd0c9039043b 100644 (file)
@@ -27,10 +27,10 @@ public:
   virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
   virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0;
   virtual IOState tryHandshake() = 0;
-  virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}) = 0;
+  virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) = 0;
   virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0;
   virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
-  virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) = 0;
+  virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
   virtual bool hasBufferedData() const = 0;
   virtual std::string getServerNameIndication() const = 0;
   virtual LibsslTLSVersion getTLSVersion() const = 0;
@@ -325,12 +325,12 @@ public:
     return IOState::Done;
   }
 
-  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0})
+  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}, bool allowIncomplete=false)
   {
     if (d_conn) {
-      return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout);
+      return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
     } else {
-      return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
+      return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
     }
   }
 
@@ -340,14 +340,14 @@ public:
      return Done when toRead bytes have been read, needRead or needWrite if the IO operation
      would block.
   */
-  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead)
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false)
   {
     if (buffer.size() < toRead || pos >= toRead) {
       throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
     }
 
     if (d_conn) {
-      return d_conn->tryRead(buffer, pos, toRead);
+      return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
     }
 
     do {
@@ -365,6 +365,9 @@ public:
       }
 
       pos += static_cast<size_t>(res);
+      if (allowIncomplete) {
+        break;
+      }
     }
     while (pos < toRead);