From: Remi Gacogne Date: Mon, 12 Jul 2021 13:49:41 +0000 (+0200) Subject: tcpiohandler: Add support for partial reads X-Git-Tag: dnsdist-1.7.0-alpha1~23^2~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1342b949c027c89eb4a8e02f4ce09f65245fa7a9;p=thirdparty%2Fpdns.git tcpiohandler: Add support for partial reads --- diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index d5483edd3f..c3411a55b5 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -174,7 +174,7 @@ public: return step.nextState; } - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) override { auto step = getStep(); BOOST_REQUIRE_EQUAL(step.request, !d_client ? ExpectedStep::ExpectedRequest::readFromClient : ExpectedStep::ExpectedRequest::readFromBackend); @@ -255,7 +255,7 @@ public: { } - size_t read(void* buffer, size_t bufferSize, const struct timeval&readTimeout, const struct timeval& totalTimeout={0,0}) override + size_t read(void* buffer, size_t bufferSize, const struct timeval&readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) override { return 0; } diff --git a/pdns/misc.cc b/pdns/misc.cc index b39a36752e..af667bf668 100644 --- a/pdns/misc.cc +++ b/pdns/misc.cc @@ -110,7 +110,7 @@ size_t readn2(int fd, void* buffer, size_t len) return len; } -size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout) +size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout, bool allowIncomplete) { size_t pos = 0; struct timeval start{0,0}; @@ -123,6 +123,9 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& ssize_t got = read(fd, (char *)buffer + pos, len - pos); if (got > 0) { pos += (size_t) got; + if (allowIncomplete) { + break; + } } else if (got == 0) { throw runtime_error("EOF while reading message"); diff --git a/pdns/misc.hh b/pdns/misc.hh index 0372e8ffae..c4f5457577 100644 --- a/pdns/misc.hh +++ b/pdns/misc.hh @@ -149,7 +149,7 @@ vstringtok (Container &container, string const &in, size_t writen2(int fd, const void *buf, size_t count); inline size_t writen2(int fd, const std::string &s) { return writen2(fd, s.data(), s.size()); } size_t readn2(int fd, void* buffer, size_t len); -size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout={0,0}); +size_t readn2WithTimeout(int fd, void* buffer, size_t len, const struct timeval& idleTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false); size_t writen2WithTimeout(int fd, const void * buffer, size_t len, const struct timeval& timeout); void toLowerInPlace(string& str); diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index b5360fe9aa..4eb80627ab 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -302,7 +302,7 @@ public: return IOState::Done; } - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override { do { int res = SSL_read(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toRead - pos)); @@ -311,13 +311,16 @@ public: } else { pos += static_cast(res); + if (allowIncomplete) { + break; + } } } while (pos < toRead); return IOState::Done; } - size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout) override + size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override { size_t got = 0; struct timeval start = {0, 0}; @@ -333,6 +336,9 @@ public: } else { got += static_cast(res); + if (allowIncomplete) { + break; + } } if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) { @@ -1075,7 +1081,7 @@ public: return IOState::Done; } - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override { if (!d_handshakeDone) { /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us, @@ -1093,6 +1099,9 @@ public: } else if (res > 0) { pos += static_cast(res); + if (allowIncomplete) { + break; + } } else if (res < 0) { if (gnutls_error_is_fatal(res)) { @@ -1108,7 +1117,7 @@ public: return IOState::Done; } - size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout) override + size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override { size_t got = 0; struct timeval start{0,0}; @@ -1124,6 +1133,9 @@ public: } else if (res > 0) { got += static_cast(res); + if (allowIncomplete) { + break; + } } else if (res < 0) { if (gnutls_error_is_fatal(res)) { diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index d18a92b05a..e948b130ae 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -27,10 +27,10 @@ public: virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0; virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0; virtual IOState tryHandshake() = 0; - virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}) = 0; + virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) = 0; virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0; virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0; - virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) = 0; + virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0; virtual bool hasBufferedData() const = 0; virtual std::string getServerNameIndication() const = 0; virtual LibsslTLSVersion getTLSVersion() const = 0; @@ -325,12 +325,12 @@ public: return IOState::Done; } - size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}) + size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}, bool allowIncomplete=false) { if (d_conn) { - return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout); + return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete); } else { - return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout); + return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete); } } @@ -340,14 +340,14 @@ public: return Done when toRead bytes have been read, needRead or needWrite if the IO operation would block. */ - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) { if (buffer.size() < toRead || pos >= toRead) { throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos)); } if (d_conn) { - return d_conn->tryRead(buffer, pos, toRead); + return d_conn->tryRead(buffer, pos, toRead, allowIncomplete); } do { @@ -365,6 +365,9 @@ public: } pos += static_cast(res); + if (allowIncomplete) { + break; + } } while (pos < toRead);