From: Thomas Markwalder Date: Fri, 16 Dec 2022 20:40:33 +0000 (-0500) Subject: [#2684] Fixed TcpStreamRequest::postBuffer X-Git-Tag: Kea-2.3.4~88 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e1f1c829f6afa4ded1d14e4c956b495c9079bc5e;p=thirdparty%2Fkea.git [#2684] Fixed TcpStreamRequest::postBuffer Reworked the function, it was pretty broken, and added a unit test for it. src/lib/tcp/tcp_stream_msg.cc TcpStreamRequest::postBuffer() - rewrote it to actually work. src/lib/tcp/tests/tcp_listener_unittests.cc TEST(TcpStreamRequst, postBufferTest) - new test --- diff --git a/src/lib/tcp/tcp_stream_msg.cc b/src/lib/tcp/tcp_stream_msg.cc index 112445d34c..cd269be509 100644 --- a/src/lib/tcp/tcp_stream_msg.cc +++ b/src/lib/tcp/tcp_stream_msg.cc @@ -23,10 +23,37 @@ TcpStreamRequest::needData() const { size_t TcpStreamRequest::postBuffer(const void* buf, const size_t nbytes) { - if (nbytes) { - const char* bufptr = static_cast(buf); - wire_data_.insert(wire_data_.end(), bufptr, bufptr + nbytes); - if (!expected_size_ && wire_data_.size() >= sizeof(uint16_t)) { + if (!nbytes) { + // Nothing to do. + return (0); + } + + const char* bufptr = static_cast(buf); + size_t bytes_left = nbytes; + size_t wire_size = wire_data_.size(); + size_t bytes_used = 0; + while (bytes_left) { + if (expected_size_) { + // We have the length, copy as much of what we still need as we can. + size_t need_bytes = expected_size_ - wire_size; + size_t copy_bytes = (need_bytes <= bytes_left ? need_bytes : bytes_left); + wire_data_.insert(wire_data_.end(), bufptr, bufptr + copy_bytes); + bytes_left -= copy_bytes; + bytes_used += copy_bytes; + break; + } + + // Otherwise we don't know the length yet. + while (wire_size < 2 && bytes_left) { + wire_data_.push_back(*bufptr); + ++bufptr; + --bytes_left; + ++bytes_used; + ++wire_size; + } + + // If we have enough to do it, calculate the expected length. + if (wire_size == 2 ) { const uint8_t* cp = static_cast(wire_data_.data()); uint16_t len = ((unsigned int)(cp[0])) << 8; len |= ((unsigned int)(cp[1])); @@ -34,7 +61,8 @@ TcpStreamRequest::postBuffer(const void* buf, const size_t nbytes) { } } - return (nbytes); + // Return how much we used. + return (bytes_used); } std::string diff --git a/src/lib/tcp/tests/tcp_listener_unittests.cc b/src/lib/tcp/tests/tcp_listener_unittests.cc index 3d11b1cc22..70009f6dff 100644 --- a/src/lib/tcp/tests/tcp_listener_unittests.cc +++ b/src/lib/tcp/tests/tcp_listener_unittests.cc @@ -445,4 +445,95 @@ TEST_F(TcpListenerTest, filterClientsTest) { io_service_.poll(); } +// Exercises TcpStreamRequest::postBuffer() through various +// data permutations. +TEST(TcpStreamRequst, postBufferTest) { + // Struct describing a test scenario. + struct Scenario { + const std::string desc_; + // List of input buffers to submit to post. + std::list> input_buffers_; + // List of expected "request" strings conveyed. + std::list expected_strings_; + }; + + std::list scenarios{ + { + "1. Two complete messages in their own buffers", + { + { 0x00, 0x04, 0x31, 0x32, 0x33, 0x34 }, + { 0x00, 0x03, 0x35, 0x36, 0x37 }, + }, + { "1234", "567" } + }, + { + "2. Three messages: first two are in the same buffer", + { + { 0x00, 0x04, 0x31, 0x32, 0x33, 0x34, 0x00, 0x02, 0x35, 0x36 }, + { 0x00, 0x03, 0x37, 0x38, 0x39 }, + }, + { "1234", "56", "789" } + }, + { + "3. One message across three buffers", + { + { 0x00, 0x09, 0x31, 0x32, 0x33 }, + { 0x34, 0x35, 0x36, 0x37 }, + { 0x38, 0x39 }, + }, + { "123456789" } + + }, + { + "4. One message, length and data split across buffers", + { + { 0x00 }, + { 0x09, 0x31, 0x32, 0x33 }, + { 0x34, 0x35, 0x36, 0x37 }, + { 0x38, 0x39 }, + }, + { "123456789" } + + }, + }; + + for (auto scenario : scenarios ) { + SCOPED_TRACE(scenario.desc_); + std::list requests; + TcpStreamRequestPtr request; + for (auto input_buf : scenario.input_buffers_) { + // Copy the input buffer. + std::vector buf = input_buf; + + // While there is data left to use, use it. + while (buf.size()) { + // If we need a new request make one. + if (!request) { + request.reset(new TcpStreamRequest()); + } + + size_t bytes_used = request->postBuffer(static_cast(buf.data()), buf.size()); + if (!request->needData()) { + // Request is complete, save it. + requests.push_back(request); + request.reset(); + } + + // Consume bytes used. + if (bytes_used) { + buf.erase(buf.begin(), buf.begin() + bytes_used); + } + } + } + + ASSERT_EQ(requests.size(), scenario.expected_strings_.size()); + auto exp_string = scenario.expected_strings_.begin(); + for (auto recvd_request : requests) { + ASSERT_NO_THROW(recvd_request->unpack()); + EXPECT_EQ(*exp_string++, recvd_request->getRequestString()); + } + } +} + + }