]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[#2583] Added TcpConnction::responesSent()
authorThomas Markwalder <tmark@isc.org>
Thu, 10 Nov 2022 17:23:52 +0000 (12:23 -0500)
committerThomas Markwalder <tmark@isc.org>
Thu, 10 Nov 2022 19:43:23 +0000 (14:43 -0500)
src/lib/tcp/tcp_connection.*
    TcpConnection::responseSent() - new virtual method
    TcpConnection::doWrite() - calls responseSent() and
    conditionally starts idle timer

src/lib/tcp/tcp_stream_msg.h
    TcpStreamResponse::getResponseString() - added

src/lib/tcp/tests/tcp_listener_unittests.cc
    Added AuditTrail to faciliate verifying outcomes

src/lib/tcp/tcp_connection.cc
src/lib/tcp/tcp_connection.h
src/lib/tcp/tcp_stream_msg.h
src/lib/tcp/tests/tcp_listener_unittests.cc

index b4d82f07eaf235672f4f23a73257ecde59d2e019..22c8934f6890c3dc64317aac9c8bcefa3dc836ff 100644 (file)
@@ -261,7 +261,11 @@ TcpConnection::doWrite(TcpResponsePtr response) {
             }
         } else {
             // The connection remains open and we are done sending the response.
-            setupIdleTimer();
+            // If the response sent handler returns true then we should start the
+            // idle timer.
+            if (responseSent(response)) {
+                setupIdleTimer();
+            }
         }
     } catch (...) {
         stopThisConnection();
index e79256acef2145168f8234197905a6a6fce7028f..ce31b94ab94f58dc3045c5b6154f7d57b8f409fb 100644 (file)
@@ -298,6 +298,12 @@ public:
     /// @throw BadValue if the parameter is not greater than zero.
     void setReadMax(const size_t read_max);
 
+    /// @brief Determines behavior after a response has been sent.
+    ///
+    /// @param response Pointer to the response sent.
+    /// @return True if the idle timer should be started.
+    virtual bool responseSent(TcpResponsePtr response) = 0;
+
     /// @brief Returns an empty end point.
     ///
     /// @return an unitialized endpoint.
index cffdf9b8d26f0ac25ef49298317672734e0c564e..72db443014e09a66772eeeb5eb2a9c857d6cb2c6 100644 (file)
@@ -128,6 +128,13 @@ public:
     /// @brief Packs the response content into wire data buffer.
     virtual void pack();
 
+    /// @brief Fetches the unpacked response as a string.
+    ///
+    /// @return String containing the unpacked contents.
+    std::string getResponseString() const {
+        return (std::string(response_.begin(), response_.end()));
+    };
+
 private:
     /// @brief Unpacked response data to send.
     std::vector<uint8_t> response_;
index 2026ab83f65d621eb29b1f73e0c35a5385af0886..c20dd1d6e1001fdf2e193442731dc73b02901d47 100644 (file)
@@ -63,23 +63,126 @@ const long SHORT_IDLE_TIMEOUT = 200;
 /// @brief Test timeout (ms).
 const long TEST_TIMEOUT = 10000;
 
+/// @brief Describes stream message sent over a connection.
+class AuditEntry {
+public:
+    enum Direction {
+        INBOUND,  // data received
+        OUTBOUND  // data sent
+    };
+
+    /// @brief Constructor
+    ///
+    /// @param connection_id Id of the client to whom the entry pertains
+    /// @param direction INBOUND for data received, OUTBOUND for data sent
+    /// @param data string form of the data involved
+    AuditEntry(size_t connection_id, const AuditEntry::Direction& direction, const std::string& data)
+        : connection_id_(connection_id), direction_(direction), data_(data) {
+    }
+
+    /// @brief Equality operator.
+    ///
+    /// @param other value to be compared.
+    bool operator==(const AuditEntry& other) const {
+        return ((connection_id_ == other.connection_id_) &&
+                (direction_ == other.direction_) &&
+                (data_ == other.data_));
+    }
+
+    /// @brief Unique client identifier.
+    size_t connection_id_;
+
+    /// @brief Indicates which direction the data traveled
+    Direction direction_;
+
+    /// @brief Contains the data sent or received.
+    std::string data_;
+};
+
+std::ostream&
+operator<<(std::ostream& os, const AuditEntry& entry) {
+    os << "{ " << entry.connection_id_ << ", "
+       << (entry.direction_ == AuditEntry::INBOUND ? "I" : "O") << ", "
+       << entry.data_ << " }";
+    return (os);
+}
+
+/// @brief Contains the data receipt/transmission history for an arbitrary number
+/// of connections.
+class AuditTrail {
+public:
+    /// @brief Adds an entry to the audit trail.
+    ///
+    /// @param connection_id Id of the client to whom the entry pertains
+    /// @param direction INBOUND for data received, OUTBOUND for data sent
+    /// @param data string form of the data involved
+    void addEntry(size_t connection_id, const AuditEntry::Direction& direction, const std::string& data) {
+        // will need a mutex
+        entries_.push_back(AuditEntry(connection_id, direction, data));
+    }
+
+    /// @brief Returns a list of AuditEntry(s) for a given connection.
+    ///
+    /// @param connection_id Id of the desired connection
+    /// @return A list of entries for the connection or an empty list if none are found.
+    std::list<AuditEntry> getConnectionTrail(size_t connection_id) {
+        std::list<AuditEntry> conn_entries;
+        for (auto entry_it = entries_.begin(); entry_it != entries_.end(); ++entry_it) {
+            if ((*entry_it).connection_id_ == connection_id) {
+                conn_entries.push_back(*entry_it);
+            }
+        }
+
+        return (conn_entries);
+    }
+
+    /// @brief Dumps the audit trail as a string.
+    std::string dump() {
+        std::stringstream ss;
+        for (auto entry_it = entries_.begin(); entry_it != entries_.end(); ++entry_it) {
+            ss << (*entry_it) << std::endl;
+        }
+
+        return (ss.str());
+    }
+
+    /// @brief Contains the audit entries.
+    std::list<AuditEntry> entries_;
+};
+
+/// @brief Defines a pointer to an AuditTrail
+typedef boost::shared_ptr<AuditTrail> AuditTrailPtr;
+
+/// @brief Derivation of TcpConnection used for testing.
 class TcpTestConnection : public TcpConnection {
 public:
+    /// @brief Constructor
     TcpTestConnection(IOService& io_service,
                       const TcpConnectionAcceptorPtr& acceptor,
                       const TlsContextPtr& tls_context,
                       TcpConnectionPool& connection_pool,
                       const TcpConnectionAcceptorCallback& acceptor_callback,
                       const TcpConnectionFilterCallback& filter_callback,
-                      const long idle_timeout)
+                      const long idle_timeout,
+                      size_t connection_id,
+                      AuditTrailPtr audit_trail)
      : TcpConnection(io_service, acceptor, tls_context, connection_pool,
-                     acceptor_callback, filter_callback, idle_timeout) {
+                     acceptor_callback, filter_callback, idle_timeout),
+                     connection_id_(connection_id), audit_trail_(audit_trail) {
     }
 
+    /// @brief Creats a new empty request ready to receive data.
     virtual TcpRequestPtr createRequest() {
         return (TcpStreamRequestPtr(new TcpStreamRequest()));
     }
 
+    /// @brief Processes a completely received request.
+    ///
+    /// Adds the request to the audit trail, then forms and sends a response.
+    /// If the request is "I am done", the response is "good bye" which should instruct
+    /// the client to disconnect.
+    ///
+    /// @param request Request to process.
     virtual void requestReceived(TcpRequestPtr request) {
         TcpStreamRequestPtr req = boost::dynamic_pointer_cast<TcpStreamRequest>(request);
         if (!req) {
@@ -88,6 +191,8 @@ public:
 
         req->unpack();
         auto request_str = req->getRequestString();
+        audit_trail_->addEntry(connection_id_, AuditEntry::INBOUND, request_str);
+
         std::ostringstream os;
         if (request_str == "I am done") {
             os << "good bye";
@@ -100,6 +205,29 @@ public:
         resp->pack();
         asyncSendResponse(resp);
     }
+
+    /// @brief Processes a response once it has been sent.
+    ///
+    /// Adds the response to the audit trail and returns true, signifying
+    /// that the connection should start the idle timer.
+    ///
+    /// @param response Response that was sent to the remote endpoint.
+    virtual bool responseSent(TcpResponsePtr response) {
+        TcpStreamResponsePtr resp = boost::dynamic_pointer_cast<TcpStreamResponse>(response);
+        if (!resp) {
+            isc_throw(isc::Unexpected, "resp not a TcpStreamResponse");
+        }
+
+        audit_trail_->addEntry(connection_id_, AuditEntry::OUTBOUND, resp->getResponseString());
+        return (true);
+    }
+
+private:
+    /// @brief Id of this connection.
+    size_t connection_id_;
+
+    /// @brief Provides request/response history.
+    AuditTrailPtr audit_trail_;
 };
 
 /// @brief Implementation of the TCPListener used in tests.
@@ -107,7 +235,7 @@ public:
 /// Implements simple stream in/out listener.
 class TcpTestListener : public TcpListener {
 public:
-
+    /// @brief Constructor
     TcpTestListener(IOService& io_service,
                     const IOAddress& server_address,
                     const unsigned short server_port,
@@ -117,7 +245,8 @@ public:
                     const size_t read_max = 32 * 1024)
         : TcpListener(io_service, server_address, server_port,
                       tls_context, idle_timeout, filter_callback),
-                      read_max_(read_max) {
+                      read_max_(read_max), next_connection_id_(0),
+                      audit_trail_(new AuditTrail()) {
     }
 
 protected:
@@ -132,16 +261,22 @@ protected:
     virtual TcpConnectionPtr createConnection(
             const TcpConnectionAcceptorCallback& acceptor_callback,
             const TcpConnectionFilterCallback& connection_filter) {
-        TcpConnectionPtr conn(new TcpTestConnection(io_service_, acceptor_,
-                                                    tls_context_, connections_,
-                                                    acceptor_callback, connection_filter,
-                                                    idle_timeout_));
+        TcpConnectionPtr conn(new TcpTestConnection(io_service_, acceptor_, tls_context_,
+                                                    connections_, acceptor_callback,
+                                                    connection_filter, idle_timeout_,
+                                                    ++next_connection_id_,  audit_trail_));
         conn->setReadMax(read_max_);
         return (conn);
     }
 
     /// @brief Maximum size of a single socket read
     size_t read_max_;
+
+    /// @brief Id to use for the next connection.
+    size_t next_connection_id_;
+
+public:
+    AuditTrailPtr audit_trail_;
 };
 
 /// @brief Test fixture class for @ref TcpListener.
@@ -296,6 +431,19 @@ TEST_F(TcpListenerTest, listen) {
     EXPECT_TRUE(client->receiveDone());
     EXPECT_FALSE(client->expectedEof());
 
+    // Verify the audit trail for the connection.
+    // Sanity check to make sure we don't have more entries than we expect.
+    ASSERT_EQ(listener.audit_trail_->entries_.size(), 2);
+
+    // Create the list of expected entries.
+    std::list<AuditEntry> expected_entries {
+        { 1, AuditEntry::INBOUND, "I am done" },
+        { 1, AuditEntry::OUTBOUND, "good bye" }
+    };
+
+    // Verify the audit trail.
+    ASSERT_EQ(expected_entries, listener.audit_trail_->getConnectionTrail(1));
+
     listener.stop();
     io_service_.poll();
 }
@@ -371,9 +519,20 @@ TEST_F(TcpListenerTest, multipleClientsListen) {
     ASSERT_NO_THROW(runIOService());
     ASSERT_EQ(num_clients, clients_.size());
 
+    size_t connection_id = 1;
     for (auto client = clients_.begin(); client != clients_.end(); ++client) {
         EXPECT_TRUE((*client)->receiveDone());
         EXPECT_FALSE((*client)->expectedEof());
+        // Create the list of expected entries.
+        std::list<AuditEntry> expected_entries {
+            { connection_id, AuditEntry::INBOUND, "I am done" },
+            { connection_id, AuditEntry::OUTBOUND, "good bye" }
+        };
+
+        // Fetch the entries for this connection.
+        auto entries = listener.audit_trail_->getConnectionTrail(connection_id);
+        ASSERT_EQ(expected_entries, entries);
+        ++connection_id;
     }
 
     listener.stop();
@@ -400,10 +559,29 @@ TEST_F(TcpListenerTest, multipleRequetsPerClients) {
     ASSERT_EQ(num_clients, clients_.size());
 
     std::list<std::string>expected_responses{ "echo one", "echo two", "echo three", "good bye"};
+    size_t connection_id = 1;
     for (auto client = clients_.begin(); client != clients_.end(); ++client) {
         EXPECT_TRUE((*client)->receiveDone());
         EXPECT_FALSE((*client)->expectedEof());
         EXPECT_EQ(expected_responses, (*client)->getResponses());
+
+        // Verify the connection's audit trail.
+        // Create the list of expected entries.
+        std::list<AuditEntry> expected_entries {
+            { connection_id, AuditEntry::INBOUND, "one" },
+            { connection_id, AuditEntry::OUTBOUND, "echo one" },
+            { connection_id, AuditEntry::INBOUND, "two" },
+            { connection_id, AuditEntry::OUTBOUND, "echo two" },
+            { connection_id, AuditEntry::INBOUND, "three" },
+            { connection_id, AuditEntry::OUTBOUND, "echo three" },
+            { connection_id, AuditEntry::INBOUND, "I am done" },
+            { connection_id, AuditEntry::OUTBOUND, "good bye" }
+        };
+
+        // Fetch the entries for this connection.
+        auto entries = listener.audit_trail_->getConnectionTrail(connection_id);
+        ASSERT_EQ(expected_entries, entries);
+        ++connection_id;
     }
 
     listener.stop();
@@ -441,10 +619,25 @@ TEST_F(TcpListenerTest, filterClientsTest) {
             // These clients should have been accepted and received responses.
             EXPECT_TRUE((*client)->receiveDone());
             EXPECT_FALSE((*client)->expectedEof());
+
+            // Now verify the AuditTrail.
+            // Create the list of expected entries.
+            std::list<AuditEntry> expected_entries {
+                { i+1, AuditEntry::INBOUND, "I am done" },
+                { i+1, AuditEntry::OUTBOUND, "good bye" }
+            };
+
+            auto entries = listener.audit_trail_->getConnectionTrail(i+1);
+            ASSERT_EQ(expected_entries, entries);
+
         } else {
             // These clients should have been rejected and gotten EOF'd.
             EXPECT_FALSE((*client)->receiveDone());
             EXPECT_TRUE((*client)->expectedEof());
+
+            // Verify connection recorded no audit entries.
+            auto entries = listener.audit_trail_->getConnectionTrail(i+1);
+            ASSERT_EQ(entries.size(), 0);
         }
 
         ++i;