]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[#26,!106] Guard HTTP client connect with timeout.
authorMarcin Siodelski <marcin@isc.org>
Mon, 5 Nov 2018 13:12:53 +0000 (14:12 +0100)
committerTomek Mrugalski <tomasz@isc.org>
Wed, 7 Nov 2018 09:30:21 +0000 (16:30 +0700)
src/lib/http/client.cc
src/lib/http/client.h
src/lib/http/tests/server_client_unittests.cc

index f775330d64cd2561cb6aae6499ec1fb01eae8a03..b1474fb4e72cb1de11c3405956783356e37d470e 100644 (file)
@@ -122,8 +122,11 @@ public:
     /// @param request_timeout Request timeout in milliseconds.
     /// @param callback Pointer to the callback function to be invoked when the
     /// transaction completes.
+    /// @param connect_callback Pointer to the callback function to be invoked when
+    /// the client connects to the server.
     void doTransaction(const HttpRequestPtr& request, const HttpResponsePtr& response,
-                       const long request_timeout, const HttpClient::RequestHandler& callback);
+                       const long request_timeout, const HttpClient::RequestHandler& callback,
+                       const HttpClient::ConnectHandler& connect_callback);
 
     /// @brief Closes the socket and cancels the request timer.
     void close();
@@ -173,9 +176,10 @@ private:
     /// If the connection is successfully established, this callback will start
     /// to asynchronously send the request over the socket.
     ///
-    /// @param request_timeout Request timeout specified for this transaction.
+    /// @param Pointer to the callback to be invoked when client connects to
+    /// the server.
     /// @param ec Error code being a result of the connection attempt.
-    void connectCallback(const long request_timeout,
+    void connectCallback(HttpClient::ConnectHandler connect_callback,
                          const boost::system::error_code& ec);
 
     /// @brief Local callback invoked when an attempt to send a portion of data
@@ -269,6 +273,8 @@ public:
     /// be stored.
     /// @param request_timeout Requested timeout for the transaction.
     /// @param callback Pointer to the user callback for this request.
+    /// @param connect_callback Pointer to the user callback invoked when
+    /// the client connects to the server.
     ///
     /// @return true if the request for the given URL has been retrieved,
     /// false if there are no more requests queued for this URL.
@@ -276,7 +282,8 @@ public:
                         HttpRequestPtr& request,
                         HttpResponsePtr& response,
                         long& request_timeout,
-                        HttpClient::RequestHandler& callback) {
+                        HttpClient::RequestHandler& callback,
+                        HttpClient::ConnectHandler& connect_callback) {
         // Check if there is a queue for this URL. If there is no queue, there
         // is no request queued either.
         auto it = queue_.find(url);
@@ -289,6 +296,7 @@ public:
                 response = desc.response_;
                 request_timeout = desc.request_timeout_,
                 callback = desc.callback_;
+                connect_callback = desc.connect_callback_;
                 return (true);
             }
         }
@@ -307,13 +315,16 @@ public:
     /// stored.
     /// @param request_timeout Requested timeout for the transaction in
     /// milliseconds.
-    /// @param callback Pointer to the user callback to be invoked when the
+    /// @param request_callback Pointer to the user callback to be invoked when the
     /// transaction ends.
+    /// @param connect_callback Pointer to the user callback to be invoked when the
+    /// client connects to the server.
     void queueRequest(const Url& url,
                       const HttpRequestPtr& request,
                       const HttpResponsePtr& response,
                       const long request_timeout,
-                      const HttpClient::RequestHandler& callback) {
+                      const HttpClient::RequestHandler& request_callback,
+                      const HttpClient::ConnectHandler& connect_callback) {
         auto it = conns_.find(url);
         if (it != conns_.end()) {
             ConnectionPtr conn = it->second;
@@ -322,12 +333,13 @@ public:
                 // Connection is busy, so let's queue the request.
                 queue_[url].push(RequestDescriptor(request, response,
                                                    request_timeout,
-                                                   callback));
+                                                   request_callback,
+                                                   connect_callback));
 
             } else {
                 // Connection is idle, so we can start the transaction.
                 conn->doTransaction(request, response, request_timeout,
-                                    callback);
+                                    request_callback, connect_callback);
             }
 
         } else {
@@ -335,7 +347,8 @@ public:
             // it and start the transaction.
             ConnectionPtr conn(new Connection(io_service_, shared_from_this(),
                                               url));
-            conn->doTransaction(request, response, request_timeout, callback);
+            conn->doTransaction(request, response, request_timeout, request_callback,
+                                connect_callback);
             conns_[url] = conn;
         }
     }
@@ -389,13 +402,17 @@ private:
         /// be stored.
         /// @param request_timeout Requested timeout for the transaction.
         /// @param callback Pointer to the user callback.
+        /// @param connect_callback pointer to the user callback to be invoked
+        /// when the client connects to the server.
         RequestDescriptor(const HttpRequestPtr& request,
                           const HttpResponsePtr& response,
                           const long request_timeout,
-                          const HttpClient::RequestHandler& callback)
+                          const HttpClient::RequestHandler& callback,
+                          const HttpClient::ConnectHandler& connect_callback)
             : request_(request), response_(response),
               request_timeout_(request_timeout),
-              callback_(callback) {
+              callback_(callback),
+              connect_callback_(connect_callback) {
         }
 
         /// @brief Holds pointer to the request.
@@ -406,6 +423,8 @@ private:
         long request_timeout_;
         /// @brief Holds pointer to the user callback.
         HttpClient::RequestHandler callback_;
+        /// @brief Holds pointer to the user callback for connect.
+        HttpClient::ConnectHandler connect_callback_;
     };
 
     /// @brief Holds the queue of requests for different URLs.
@@ -436,7 +455,8 @@ void
 Connection::doTransaction(const HttpRequestPtr& request,
                           const HttpResponsePtr& response,
                           const long request_timeout,
-                          const HttpClient::RequestHandler& callback) {
+                          const HttpClient::RequestHandler& callback,
+                          const HttpClient::ConnectHandler& connect_callback) {
     try {
         current_request_ = request;
         current_response_ = response;
@@ -467,13 +487,16 @@ Connection::doTransaction(const HttpRequestPtr& request,
             .arg(HttpMessageParserBase::logFormatHttpMessage(request->toString(),
                                                              MAX_LOGGED_MESSAGE_SIZE));
 
+        // Setup request timer.
+        scheduleTimer(request_timeout);
+
         /// @todo We're getting a hostname but in fact it is expected to be an IP address.
         /// We should extend the TCPEndpoint to also accept names. Currently, it will fall
         /// over for names.
         TCPEndpoint endpoint(url_.getStrippedHostname(),
                              static_cast<unsigned short>(url_.getPort()));
         SocketCallback socket_cb(boost::bind(&Connection::connectCallback, shared_from_this(),
-                                             request_timeout, _1));
+                                             connect_callback, _1));
 
         // Establish new connection or use existing connection.
         socket_.open(&endpoint, socket_cb);
@@ -557,10 +580,11 @@ Connection::terminate(const boost::system::error_code& ec,
     HttpRequestPtr request;
     long request_timeout;
     HttpClient::RequestHandler callback;
+    HttpClient::ConnectHandler connect_callback;
     ConnectionPoolPtr conn_pool = conn_pool_.lock();
     if (conn_pool && conn_pool->getNextRequest(url_, request, response, request_timeout,
-                                               callback)) {
-        doTransaction(request, response, request_timeout, callback);
+                                               callback, connect_callback)) {
+        doTransaction(request, response, request_timeout, callback, connect_callback);
     }
 }
 
@@ -599,7 +623,17 @@ Connection::doReceive() {
 }
 
 void
-Connection::connectCallback(const long request_timeout, const boost::system::error_code& ec) {
+Connection::connectCallback(HttpClient::ConnectHandler connect_callback,
+                            const boost::system::error_code& ec) {
+    // Run user defined connect callback if specified.
+    if (connect_callback) {
+        // If the user defined callback indicates that the connection
+        // should not be continued.
+        if (!connect_callback(ec)) {
+            return;
+        }
+    }
+
     // In some cases the "in progress" status code may be returned. It doesn't
     // indicate an error. Sending the request over the socket is expected to
     // be successful. Getting such status appears to be highly dependent on
@@ -610,9 +644,6 @@ Connection::connectCallback(const long request_timeout, const boost::system::err
         terminate(ec);
 
     } else {
-        // Setup request timer.
-        scheduleTimer(request_timeout);
-
         // Start sending the request asynchronously.
         doSend();
     }
@@ -736,8 +767,9 @@ HttpClient::HttpClient(IOService& io_service)
 void
 HttpClient::asyncSendRequest(const Url& url, const HttpRequestPtr& request,
                              const HttpResponsePtr& response,
-                             const HttpClient::RequestHandler& callback,
-                             const HttpClient::RequestTimeout& request_timeout) {
+                             const HttpClient::RequestHandler& request_callback,
+                             const HttpClient::RequestTimeout& request_timeout,
+                             const HttpClient::ConnectHandler& connect_callback) {
     if (!url.isValid()) {
         isc_throw(HttpClientError, "invalid URL specified for the HTTP client");
     }
@@ -750,12 +782,12 @@ HttpClient::asyncSendRequest(const Url& url, const HttpRequestPtr& request,
         isc_throw(HttpClientError, "HTTP response must not be null");
     }
 
-    if (!callback) {
+    if (!request_callback) {
         isc_throw(HttpClientError, "callback for HTTP transaction must not be null");
     }
 
     impl_->conn_pool_->queueRequest(url, request, response, request_timeout.value_,
-                                    callback);
+                                    request_callback, connect_callback);
 }
 
 void
index 5731826ad96f69ad23060c0dfe5f7699e6915653..132eda3477c2ee5595d9205e97730cdfe36d54ff 100644 (file)
@@ -82,6 +82,12 @@ public:
                                const HttpResponsePtr&,
                                const std::string&)> RequestHandler;
 
+    /// @brief Optional handler invoked when client connects to the server.
+    ///
+    /// Returned boolean value indicates whether the client should continue
+    /// connecting to the server (if true) or not (false).
+    typedef std::function<bool(const boost::system::error_code&)> ConnectHandler;
+
     /// @brief Constructor.
     ///
     /// @param io_service IO service to be used by the HTTP client.
@@ -141,16 +147,21 @@ public:
     /// @param url URL where the request should be send.
     /// @param request Pointer to the object holding a request.
     /// @param response Pointer to the object where response should be stored.
-    /// @param callback Pointer to the user callback function.
+    /// @param request_callback Pointer to the user callback function invoked
+    /// when transaction ends.
     /// @param request_timeout Timeout for the transaction in milliseconds.
+    /// @param connect_callback Optional callback invoked when the client
+    /// connects to the server.
     ///
     /// @throw HttpClientError If invalid arguments were provided.
     void asyncSendRequest(const Url& url,
                           const HttpRequestPtr& request,
                           const HttpResponsePtr& response,
-                          const RequestHandler& callback,
+                          const RequestHandler& request_callback,
                           const RequestTimeout& request_timeout =
-                          RequestTimeout(10000));
+                          RequestTimeout(10000),
+                          const ConnectHandler& connect_callback =
+                          ConnectHandler());
 
     /// @brief Closes all connections.
     void stop();
index ebb5e03f50f247d49181860961e05b26f05a6701..527f739ff80d52c509f99e6f2d7bde36f78fc21e 100644 (file)
@@ -1249,5 +1249,58 @@ TEST_F(HttpClientTest, clientRequestTimeout) {
     ASSERT_NO_THROW(runIOService());
 }
 
+// Test that client times out when connection takes too long.
+TEST_F(HttpClientTest, clientConnectTimeout) {
+    // Start the server.
+    ASSERT_NO_THROW(listener_.start());
+
+    // Create the client.
+    HttpClient client(io_service_);
+
+    // Specify the URL of the server.
+    Url url("http://127.0.0.1:18123");
+
+    unsigned cb_num = 0;
+
+    PostHttpRequestJsonPtr request = createRequest("sequence", 1);
+    HttpResponseJsonPtr response(new HttpResponseJson());
+    ASSERT_NO_THROW(client.asyncSendRequest(url, request, response,
+        [this, &cb_num](const boost::system::error_code& ec,
+                        const HttpResponsePtr& response,
+                        const std::string&) {
+        if (++cb_num > 1) {
+            io_service_.stop();
+        }
+        // In this particular case we know exactly the type of the
+        // IO error returned, because the client explicitly sets this
+        // error code.
+        EXPECT_TRUE(ec.value() == boost::asio::error::timed_out);
+        // There should be no response returned.
+        EXPECT_FALSE(response);
+
+    }, HttpClient::RequestTimeout(100),
+
+       // This callback is invoked upon an attempt to connect to the
+       // server. The false value indicates to the HttpClient to not
+       // try to send a request to the server. This simulates the
+       // case of connect() taking very long and should eventually
+       // cause the transaction to time out.
+       [](const boost::system::error_code& ec) {
+           return (false);
+    }));
+
+    // Create another request after the timeout. It should be handled ok.
+    ASSERT_NO_THROW(client.asyncSendRequest(url, request, response,
+                    [this, &cb_num](const boost::system::error_code& /*ec*/, const HttpResponsePtr&,
+               const std::string&) {
+        if (++cb_num > 1) {
+            io_service_.stop();
+        }
+    }));
+
+    // Actually trigger the requests.
+    ASSERT_NO_THROW(runIOService());
+}
+
 
 }