]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[#2684] Fixed TcpStreamRequest::postBuffer
authorThomas Markwalder <tmark@isc.org>
Fri, 16 Dec 2022 20:40:33 +0000 (15:40 -0500)
committerThomas Markwalder <tmark@isc.org>
Fri, 16 Dec 2022 21:49:14 +0000 (16:49 -0500)
Reworked the function, it was pretty broken,
and added a unit test for it.

src/lib/tcp/tcp_stream_msg.cc
    TcpStreamRequest::postBuffer() - rewrote it to
    actually work.

src/lib/tcp/tests/tcp_listener_unittests.cc
    TEST(TcpStreamRequst, postBufferTest)  - new test

src/lib/tcp/tcp_stream_msg.cc
src/lib/tcp/tests/tcp_listener_unittests.cc

index 112445d34cdceafb92c2a3db5234bdd9e39c7310..cd269be509f5afab362ed7f3be15b87c07da3413 100644 (file)
@@ -23,10 +23,37 @@ TcpStreamRequest::needData() const {
 
 size_t
 TcpStreamRequest::postBuffer(const void* buf,  const size_t nbytes) {
-    if (nbytes) {
-        const char* bufptr = static_cast<const char*>(buf);
-        wire_data_.insert(wire_data_.end(), bufptr, bufptr + nbytes);
-        if (!expected_size_ && wire_data_.size() >= sizeof(uint16_t)) {
+    if (!nbytes) {
+        // Nothing to do.
+        return (0);
+    }
+
+    const char* bufptr = static_cast<const char*>(buf);
+    size_t bytes_left = nbytes;
+    size_t wire_size = wire_data_.size();
+    size_t bytes_used = 0;
+    while (bytes_left) {
+        if (expected_size_) {
+            // We have the length, copy as much of what we still need as we can.
+            size_t need_bytes = expected_size_ - wire_size;
+            size_t copy_bytes = (need_bytes <= bytes_left ? need_bytes : bytes_left);
+            wire_data_.insert(wire_data_.end(), bufptr, bufptr + copy_bytes);
+            bytes_left -= copy_bytes;
+            bytes_used += copy_bytes;
+            break;
+        }
+
+        // Otherwise we don't know the length yet.
+        while (wire_size < 2 && bytes_left) {
+            wire_data_.push_back(*bufptr);
+            ++bufptr;
+            --bytes_left;
+            ++bytes_used;
+            ++wire_size;
+        }
+
+        // If we have enough to do it, calculate the expected length.
+        if (wire_size == 2 ) {
             const uint8_t* cp = static_cast<const uint8_t*>(wire_data_.data());
             uint16_t len = ((unsigned int)(cp[0])) << 8;
             len |= ((unsigned int)(cp[1]));
@@ -34,7 +61,8 @@ TcpStreamRequest::postBuffer(const void* buf,  const size_t nbytes) {
         }
     }
 
-    return (nbytes);
+    // Return how much we used.
+    return (bytes_used);
 }
 
 std::string
index 3d11b1cc22067edcf03a75e10260d1e84ce75182..70009f6dff2be3ff231f76a77a734b33ebb6f187 100644 (file)
@@ -445,4 +445,95 @@ TEST_F(TcpListenerTest, filterClientsTest) {
     io_service_.poll();
 }
 
+// Exercises TcpStreamRequest::postBuffer() through various
+// data permutations.
+TEST(TcpStreamRequst, postBufferTest) {
+    // Struct describing a test scenario.
+    struct Scenario {
+        const std::string desc_;
+        // List of input buffers to submit to post.
+        std::list<std::vector<uint8_t>> input_buffers_;
+        // List of expected "request" strings conveyed.
+        std::list<std::string> expected_strings_;
+    };
+
+    std::list<Scenario> scenarios{
+    {
+        "1. Two complete messages in their own buffers",
+        {
+            { 0x00, 0x04, 0x31, 0x32, 0x33, 0x34 },
+            { 0x00, 0x03, 0x35, 0x36, 0x37 },
+        },
+        { "1234", "567" }
+    },
+    {
+        "2. Three messages: first two are in the same buffer",
+        {
+            { 0x00, 0x04, 0x31, 0x32, 0x33, 0x34, 0x00, 0x02, 0x35, 0x36 },
+            { 0x00, 0x03, 0x37, 0x38, 0x39 },
+        },
+        { "1234", "56", "789" }
+    },
+    {
+        "3. One message across three buffers",
+        {
+            { 0x00, 0x09, 0x31, 0x32, 0x33 },
+            { 0x34, 0x35, 0x36, 0x37 },
+            { 0x38, 0x39 },
+        },
+        { "123456789" }
+
+    },
+    {
+        "4. One message, length and data split across buffers",
+        {
+            { 0x00 },
+            { 0x09, 0x31, 0x32, 0x33 },
+            { 0x34, 0x35, 0x36, 0x37 },
+            { 0x38, 0x39 },
+        },
+        { "123456789" }
+
+    },
+    };
+
+    for (auto scenario : scenarios ) {
+        SCOPED_TRACE(scenario.desc_);
+        std::list<TcpStreamRequestPtr> requests;
+        TcpStreamRequestPtr request;
+        for (auto input_buf : scenario.input_buffers_) {
+            // Copy the input buffer.
+            std::vector<uint8_t> buf = input_buf;
+
+            // While there is data left to use, use it.
+            while (buf.size()) {
+                // If we need a new request make one.
+                if (!request) {
+                    request.reset(new TcpStreamRequest());
+                }
+
+                size_t bytes_used = request->postBuffer(static_cast<void*>(buf.data()), buf.size());
+                if (!request->needData()) {
+                    // Request is complete, save it.
+                    requests.push_back(request);
+                    request.reset();
+                }
+
+                // Consume bytes used.
+                if (bytes_used) {
+                    buf.erase(buf.begin(), buf.begin() + bytes_used);
+                }
+            }
+        }
+
+        ASSERT_EQ(requests.size(), scenario.expected_strings_.size());
+        auto exp_string = scenario.expected_strings_.begin();
+        for (auto recvd_request : requests) {
+            ASSERT_NO_THROW(recvd_request->unpack());
+            EXPECT_EQ(*exp_string++, recvd_request->getRequestString());
+        }
+    }
+}
+
+
 }