]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[(no branch, rebasing 1798-remove-tls-stream-clear-operation)] [#1798] Checkpoint...
authorFrancis Dupont <fdupont@isc.org>
Tue, 13 Apr 2021 21:58:03 +0000 (23:58 +0200)
committerFrancis Dupont <fdupont@isc.org>
Tue, 11 May 2021 16:02:34 +0000 (18:02 +0200)
src/lib/asiolink/tls_socket.h
src/lib/http/client.cc
src/lib/http/tests/server_client_unittests.cc

index 67a9233f48db484d8367f4391059c804de3c600d..b5ce4b6e364da177669ddc19f174aab1a2601557 100644 (file)
@@ -276,12 +276,6 @@ TLSSocket<C>::TLSSocket(IOService& service, TlsContextPtr context) :
 
 template <typename C> void
 TLSSocket<C>::open(const IOEndpoint* endpoint, C& callback) {
-    // If socket is open on this end but has been closed by the peer,
-    // we need to reconnect.
-    if (socket_.is_open() && !isUsable()) {
-        socket_.close();
-    }
-
     // Ignore opens on already-open socket.  Don't throw a failure because
     // of uncertainties as to what precedes when using asynchronous I/O.
     // Also allows us a treat a passed-in socket as a self-managed socket.
index ccb8ecad3aa36300b130e28ecf933ea31e8623ca..ced8422d79456b7b81463cbc9a8099fbbb90ab52 100644 (file)
@@ -90,10 +90,10 @@ typedef boost::shared_ptr<ConnectionPool> ConnectionPoolPtr;
 /// @brief Client side HTTP connection to the server.
 ///
 /// Each connection is established with a unique destination identified by the
-/// specified URL. Multiple requests to the same destination can be sent over
-/// the same connection, if the connection is persistent. If the server closes
-/// the TCP connection (e.g. after sending a response), the connection can
-/// be re-established (using the same @c Connection object).
+/// specified URL and TLS context. Multiple requests to the same destination
+/// can be sent overthe same connection, if the connection is persistent.
+/// If the server closes the TCP connection (e.g. after sending a response),
+/// the connection is closed.
 ///
 /// If new request is created while the previous request is still in progress,
 /// the new request is stored in the FIFO queue. The queued requests to the
@@ -154,6 +154,17 @@ public:
     /// @return true if transaction has been initiated, false otherwise.
     bool isTransactionOngoing() const;
 
+    /// @brief Checks if the socket has been closed.
+    ///
+    /// @return true if the socket has been closed.
+    bool isClosed() const;
+
+    /// @brief Checks if the peer has closed the socket at its side.
+    ///
+    /// If the socket is open but is not usable the peer has closed
+    /// the socket at its side so we close it.
+    void isClosedByPeer();
+
     /// @brief Checks if a socket descriptor belongs to this connection.
     ///
     /// @param socket_fd socket descriptor to check
@@ -291,6 +302,8 @@ private:
 
     /// @brief Asynchronously performs the TLS handshake.
     ///
+    /// The TLS handshake is performed once on TLS sockets.
+    ///
     /// @param transid Current transaction id.
     void doHandshake(const uint64_t transid);
 
@@ -380,6 +393,9 @@ private:
     /// @brief URL for this connection.
     Url url_;
 
+    /// @brief TLS context for this connection.
+    TlsContextPtr tls_context_;
+
     /// @brief TCP socket to be used for this connection.
     std::unique_ptr<TCPSocket<SocketCallback> > tcp_socket_;
 
@@ -419,6 +435,12 @@ private:
     /// @brief Flag to indicate that a transaction is running.
     std::atomic<bool> started_;
 
+    /// @brief Flag to indicate that the TLS handshake has to be performed.
+    std::atomic<bool> need_handshake_;
+
+    /// @brief Flag to indicate that the socket was closed.
+    std::atomic<bool> closed_;
+
     /// @brief Mutex to protect the internal state.
     std::mutex mutex_;
 };
@@ -454,18 +476,31 @@ public:
         closeAll();
     }
 
-    /// @brief Process next queued request for the given URL.
+    /// @brief Process next queued request for the given URL and TLS context.
     ///
     /// @param url URL for which next queued request should be processed.
-    void processNextRequest(const Url& url) {
+    /// @param tls_context TLS context for which next queued request
+    /// should be processed.
+    void processNextRequest(const Url& url, const TlsContextPtr& tls_context) {
         if (MultiThreadingMgr::instance().getMode()) {
             std::lock_guard<std::mutex> lk(mutex_);
-            return (processNextRequestInternal(url));
+            return (processNextRequestInternal(url, tls_context));
         } else {
-            return (processNextRequestInternal(url));
+            return (processNextRequestInternal(url, tls_context));
         }
     }
 
+    /// @brief Schedule processing of next queued request.
+    ///
+    /// @param url URL for which next queued request should be processed.
+    /// @param tls_context TLS context for which next queued request
+    /// should be processed.
+    void postProcessNextRequest(const Url& url,
+                                const TlsContextPtr& tls_context) {
+        io_service_.post(std::bind(&ConnectionPool::processNextRequest,
+                                   shared_from_this(), url, tls_context));
+    }
+
     /// @brief Queue next request for sending to the server.
     ///
     /// A new transaction is started immediately, if there is no other request
@@ -543,23 +578,35 @@ public:
 
 private:
 
-    /// @brief Process next queued request for the given URL.
+    /// @brief Process next queued request for the given URL and TLS context.
     ///
     /// This method should be called in a thread safe context.
     ///
     /// @param url URL for which next queued request should be retrieved.
-    void processNextRequestInternal(const Url& url) {
+    /// @param tls_context TLS context for which next queued request
+    /// should be processed.
+    void processNextRequestInternal(const Url& url,
+                                    const TlsContextPtr& tls_context) {
         // Check if there is a queue for this URL. If there is no queue, there
         // is no request queued either.
-        DestinationPtr destination = findDestination(url);
+        DestinationPtr destination = findDestination(url, tls_context);
         if (destination) {
+            // Remove closed connections.
+            destination->garbageCollectConnections();
             if (!destination->queueEmpty()) {
                 // We have at least one queued request. Do we have an
                 // idle connection?
                 ConnectionPtr connection = destination->getIdleConnection();
                 if (!connection) {
-                    // No idle connections, so just return.
-                    return;
+                    // No idle connections.
+                    if (destination->connectionsFull()) {
+                        return;
+                    }
+                    // Room to make another connection with this destination,
+                    // so make one.
+                    connection.reset(new Connection(io_service_, tls_context,
+                                                    shared_from_this(), url));
+                    destination->addConnection(connection);
                 }
 
                 // Dequeue the oldest request and start a transaction for it using
@@ -607,13 +654,15 @@ private:
                               const HttpClient::CloseHandler& close_callback) {
         ConnectionPtr connection;
         // Find the destination for the requested URL.
-        DestinationPtr destination = findDestination(url);
+        DestinationPtr destination = findDestination(url, tls_context);
         if (destination) {
+            // Remove closed connections.
+            destination->garbageCollectConnections();
             // Found it, look for an idle connection.
             connection = destination->getIdleConnection();
         } else {
             // Doesn't exist yet so it's a new destination.
-            destination = addDestination(url);
+            destination = addDestination(url, tls_context);
         }
 
         if (!connection) {
@@ -736,16 +785,22 @@ private:
         HttpClient::CloseHandler close_callback_;
     };
 
+    /// @brief Type of URL and TLS context pairs.
+    typedef std::pair<Url, TlsContextPtr> DestinationDescriptor;
+
     /// @brief Encapsulates connections and requests for a given URL
     class Destination {
     public:
         /// @brief Constructor
         ///
         /// @param url server URL of this destination
+        /// @param tls_context server TLS context of this destination
         /// @param max_connections maximum number of concurrent connections
         /// allowed for in the list URL
-        Destination(Url url, size_t max_connections)
-        : url_(url), max_connections_(max_connections), connections_(), queue_() { }
+        Destination(Url url, TlsContextPtr tls_context, size_t max_connections)
+            : url_(url), tls_context_(tls_context),
+              max_connections_(max_connections), connections_(), queue_() {
+        }
 
         /// @brief Destructor
         ~Destination() {
@@ -793,11 +848,27 @@ private:
             connections_.clear();
         }
 
+        /// @brief Removes closed connections.
+        ///
+        /// This method should be called before @ref getIdleConnection.
+        /// @note This should be called in a thread safe context.
+        void garbageCollectConnections() {
+            for (auto it = connections_.begin(); it != connections_.end();) {
+                (*it)->isClosedByPeer();
+                if (!(*it)->isClosed()) {
+                    ++it;
+                } else {
+                    it = connections_.erase(it);
+                }
+            }
+        }
+
         /// @brief Finds the first idle connection.
         ///
         /// Iterates over the existing connections and returns the
         /// first connection which is not currently in a transaction.
         ///
+        /// @note @ref garbageCollectConnections should be called before.
         /// @return The first idle connection or an empty pointer if
         /// all connections are busy.
         ConnectionPtr getIdleConnection() {
@@ -854,13 +925,6 @@ private:
             return (max_connections_);
         }
 
-        /// @brief Fetches the URL.
-        ///
-        /// @return the URL.
-        const Url& getUrl() const {
-            return (url_);
-        }
-
         /// @brief Indicates if request queue is empty.
         ///
         /// @return true if there are no requests queued.
@@ -892,6 +956,9 @@ private:
         /// @brief URL supported by this destination.
         Url url_;
 
+        /// @brief TLS context to use with this destination.
+        TlsContextPtr tls_context_;
+
         /// @brief Maximum number of concurrent connections for this destination.
         size_t max_connections_;
 
@@ -905,27 +972,34 @@ private:
     /// @brief Pointer to a Destination.
     typedef boost::shared_ptr<Destination> DestinationPtr;
 
-    /// @brief Creates a new destination for the given URL.
+    /// @brief Creates a new destination for the given URL and TLS context.
     ///
     /// @param url URL of the new destination.
+    /// @param tls_context TLS context for the new destination.
     ///
     /// @return Pointer to the newly created destination.
     /// @note Must be called from within a thread-safe context.
-    DestinationPtr addDestination(const Url& url) {
-        DestinationPtr destination(new Destination(url, max_url_connections_));
-        destinations_[url] = destination;
+    DestinationPtr addDestination(const Url& url,
+                                  const TlsContextPtr& tls_context) {
+        const DestinationDescriptor& desc = std::make_pair(url, tls_context);
+        DestinationPtr destination(new Destination(url, tls_context,
+                                                   max_url_connections_));
+        destinations_[desc] = destination;
         return (destination);
     }
 
-    /// @brief Fetches a destination by URL
+    /// @brief Fetches a destination by URL and TLS context.
     ///
     /// @param url URL of the destination desired.
+    /// @param tls_context TLS context for the destination desired.
     ///
     /// @return pointer the desired destination, empty pointer
     /// if the destination does not exist.
     /// @note Must be called from within a thread-safe context.
-    DestinationPtr findDestination(const Url& url) const {
-        auto it = destinations_.find(url);
+    DestinationPtr findDestination(const Url& url,
+                                   const TlsContextPtr& tls_context) const {
+        const DestinationDescriptor& desc = std::make_pair(url, tls_context);
+        auto it = destinations_.find(desc);
         if (it != destinations_.end()) {
             return (it->second);
         }
@@ -933,16 +1007,19 @@ private:
         return (DestinationPtr());
     }
 
-    /// @brief Removes a destination by URL
+    /// @brief Removes a destination by URL and TLS context.
     ///
     /// Closes all of the destination's connections and
     /// discards all of its queued requests while removing
     /// the destination from the list of known destinations.
     ///
     /// @param url URL of the destination to be removed.
+    /// @param tls_context TLS context for the destination to be removed.
     /// @note Must be called from within a thread-safe context.
-    void removeDestination(const Url& url) {
-        auto it = destinations_.find(url);
+    void removeDestination(const Url& url,
+                           const TlsContextPtr& tls_context) {
+        const DestinationDescriptor& desc = std::make_pair(url, tls_context);
+        auto it = destinations_.find(desc);
         if (it != destinations_.end()) {
             it->second->closeAllConnections();
             destinations_.erase(it);
@@ -952,13 +1029,13 @@ private:
     /// @brief A reference to the IOService that drives socket IO.
     IOService& io_service_;
 
-    /// @brief Map of Destinations by URL.
-    std::map<Url, DestinationPtr> destinations_;
+    /// @brief Map of Destinations by URL and TLS context.
+    std::map<DestinationDescriptor, DestinationPtr> destinations_;
 
     /// @brief Mutex to protect the internal state.
     std::mutex mutex_;
 
-    /// @brief Maximum number of connections per URL.
+    /// @brief Maximum number of connections per URL and TLS context.
     size_t max_url_connections_;
 };
 
@@ -966,15 +1043,18 @@ Connection::Connection(IOService& io_service,
                        const TlsContextPtr& tls_context,
                        const ConnectionPoolPtr& conn_pool,
                        const Url& url)
-    : conn_pool_(conn_pool), url_(url), tcp_socket_(), tls_socket_(),
-      timer_(io_service), current_request_(), current_response_(),
-      parser_(), current_callback_(), buf_(), input_buf_(),
-      current_transid_(0), close_callback_(), started_(false) {
+    : conn_pool_(conn_pool), url_(url), tls_context_(tls_context),
+      tcp_socket_(), tls_socket_(), timer_(io_service),
+      current_request_(), current_response_(), parser_(),
+      current_callback_(), buf_(), input_buf_(), current_transid_(0),
+      close_callback_(), started_(false), need_handshake_(false),
+      closed_(false) {
     if (!tls_context) {
         tcp_socket_.reset(new asiolink::TCPSocket<SocketCallback>(io_service));
     } else {
         tls_socket_.reset(new asiolink::TLSSocket<SocketCallback>(io_service,
                                                                   tls_context));
+        need_handshake_ = true;
     }
 }
 
@@ -1013,6 +1093,34 @@ Connection::closeCallback(const bool clear) {
     }
 }
 
+void
+Connection::isClosedByPeer() {
+    // If the socket is open we check if it is possible to transmit
+    // the data over this socket by reading from it with message
+    // peeking. If the socket is not usable, we close it and then
+    // re-open it. There is a narrow window of time between checking
+    // the socket usability and actually transmitting the data over
+    // this socket, when the peer may close the connection. In this
+    // case we'll need to re-transmit but we don't handle it here.
+    if (tcp_socket_) {
+        if (tcp_socket_->getASIOSocket().is_open() &&
+            !tcp_socket_->isUsable()) {
+            closeCallback();
+            closed_ = true;
+            tcp_socket_->close();
+        }
+    } else if (tls_socket_) {
+        if (tls_socket_->getASIOSocket().is_open() &&
+            !tls_socket_->isUsable()) {
+            closeCallback();
+            closed_ = true;
+            tls_socket_->close();
+        }
+    } else {
+        isc_throw(Unexpected, "internal error: can't find the sending socket");
+    }
+}
+
 void
 Connection::doTransaction(const HttpRequestPtr& request,
                           const HttpResponsePtr& response,
@@ -1056,29 +1164,6 @@ Connection::doTransactionInternal(const HttpRequestPtr& request,
 
         buf_ = request->toString();
 
-        // If the socket is open we check if it is possible to transmit the data
-        // over this socket by reading from it with message peeking. If the socket
-        // is not usable, we close it and then re-open it. There is a narrow window of
-        // time between checking the socket usability and actually transmitting the
-        // data over this socket, when the peer may close the connection. In this
-        // case we'll need to re-transmit but we don't handle it here.
-        if (tcp_socket_) {
-            if (tcp_socket_->getASIOSocket().is_open() &&
-                !tcp_socket_->isUsable()) {
-                closeCallback();
-                tcp_socket_->close();
-            }
-        } else if (tls_socket_) {
-            if (tls_socket_->getASIOSocket().is_open() &&
-                !tls_socket_->isUsable()) {
-                closeCallback();
-                tls_socket_->close();
-            }
-        } else {
-            isc_throw(Unexpected,
-                      "internal error: can't find the sending socket");
-        }
-
         LOG_DEBUG(http_logger, isc::log::DBGLVL_TRACE_DETAIL,
                   HTTP_CLIENT_REQUEST_SEND)
             .arg(request->toBriefString())
@@ -1113,8 +1198,7 @@ Connection::doTransactionInternal(const HttpRequestPtr& request,
         }
 
         // Should never reach this point.
-        isc_throw(Unexpected,
-                  "internal error: can't find a socket to open");
+        isc_throw(Unexpected, "internal error: can't find a socket to open");
 
     } catch (const std::exception& ex) {
         // Re-throw with the expected exception type.
@@ -1137,6 +1221,7 @@ Connection::closeInternal() {
     // Pass in true to discard the callback.
     closeCallback(true);
 
+    closed_ = true;
     timer_.cancel();
     if (tcp_socket_) {
         tcp_socket_->close();
@@ -1153,6 +1238,11 @@ Connection::isTransactionOngoing() const {
     return (started_);
 }
 
+bool
+Connection::isClosed() const {
+    return (closed_);
+}
+
 bool
 Connection::isMySocket(int socket_fd) const {
     if (tcp_socket_) {
@@ -1274,16 +1364,11 @@ Connection::terminateInternal(const boost::system::error_code& ec,
         resetState();
     }
 
-    // Check if there are any requests queued for this connection and start
+    // Check if there are any requests queued for this destination and start
     // another transaction if there is at least one.
     ConnectionPoolPtr conn_pool = conn_pool_.lock();
     if (conn_pool) {
-        if (MultiThreadingMgr::instance().getMode()) {
-            UnlockGuard<std::mutex> lock(mutex_);
-            conn_pool->processNextRequest(url_);
-        } else {
-            conn_pool->processNextRequest(url_);
-        }
+        conn_pool->postProcessNextRequest(url_, tls_context_);
     }
 }
 
@@ -1297,8 +1382,8 @@ Connection::scheduleTimer(const long request_timeout) {
 
 void
 Connection::doHandshake(const uint64_t transid) {
-    // Skip the handshake if the socket is not a TLS one.
-    if (!tls_socket_) {
+    // Skip the handshake if it is not needed.
+    if (!need_handshake_) {
         doSend(transid);
         return;
     }
@@ -1422,6 +1507,7 @@ void
 Connection::handshakeCallback(HttpClient::ConnectHandler handshake_callback,
                               const uint64_t transid,
                               const boost::system::error_code& ec) {
+    need_handshake_ = false;
     if (checkPrematureTimeout(transid)) {
         return;
     }
index 1f325c05e763744e02ede592ed8dc6eb56102a5a..b6af68e24ab092ca4923a47203c68f9e80590729 100644 (file)
@@ -1443,13 +1443,16 @@ public:
         // Specify the URL of the server.
         Url url("http://127.0.0.1:18123");
 
+        // Specify the TLS context of the server.
+        TlsContextPtr tls_context;
+
         // Generate first request.
         PostHttpRequestJsonPtr request1 = createRequest("sequence", 1);
         HttpResponseJsonPtr response1(new HttpResponseJson());
 
         // Use very short timeout to make sure that it occurs before we actually
         // run the transaction.
-        ASSERT_NO_THROW(client.asyncSendRequest(url, TlsContextPtr(),
+        ASSERT_NO_THROW(client.asyncSendRequest(url, tls_context,
                                                 request1, response1,
             [](const boost::system::error_code& ec,
                const HttpResponsePtr& response,
@@ -1466,7 +1469,7 @@ public:
         if (queue_two_requests) {
             PostHttpRequestJsonPtr request2 = createRequest("sequence", 2);
             HttpResponseJsonPtr response2(new HttpResponseJson());
-            ASSERT_NO_THROW(client.asyncSendRequest(url, TlsContextPtr(),
+            ASSERT_NO_THROW(client.asyncSendRequest(url, tls_context,
                                                     request2, response2,
                 [](const boost::system::error_code& ec,
                    const HttpResponsePtr& response,