]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[#1239] http client connection and interval timer are now Kea thread safe
authorRazvan Becheriu <razvan@isc.org>
Wed, 3 Jun 2020 18:35:06 +0000 (21:35 +0300)
committerRazvan Becheriu <razvan@isc.org>
Tue, 16 Jun 2020 09:02:51 +0000 (09:02 +0000)
src/lib/asiolink/interval_timer.cc
src/lib/http/client.cc

index fb38a43160693c4c9f02d51123544984ce2a1485..5a3061e6886b8130176b76ea8329b7f3d1582e72 100644 (file)
 
 #include <boost/bind.hpp>
 #include <boost/enable_shared_from_this.hpp>
+#include <boost/noncopyable.hpp>
 #include <boost/shared_ptr.hpp>
 
 #include <exceptions/exceptions.h>
 
 #include <atomic>
+#include <mutex>
+
+using namespace std;
 
 namespace isc {
 namespace asiolink {
@@ -28,46 +32,77 @@ namespace asiolink {
 /// Please follow the link to get an example:
 /// http://think-async.com/asio/asio-1.4.8/doc/asio/tutorial/tutdaytime3.html#asio.tutorial.tutdaytime3.the_tcp_connection_class
 class IntervalTimerImpl :
-    public boost::enable_shared_from_this<IntervalTimerImpl>
-{
-private:
-    // prohibit copy
-    IntervalTimerImpl(const IntervalTimerImpl& source);
-    IntervalTimerImpl& operator=(const IntervalTimerImpl& source);
+    public boost::enable_shared_from_this<IntervalTimerImpl>,
+    public boost::noncopyable {
 public:
+
+    /// @brief Constructor.
+    ///
+    /// @param io_service The IO service used to handle events.
     IntervalTimerImpl(IOService& io_service);
+
+    /// @brief Destructor.
     ~IntervalTimerImpl();
+
+    /// @brief Setup function to register callback and start timer.
+    ///
+    /// @param cbfunc The callback function registered on timer.
+    /// @param interval The interval used to start the timer.
+    /// @param interval_mode The interval mode used by the timer.
     void setup(const IntervalTimer::Callback& cbfunc, const long interval,
                const IntervalTimer::Mode& interval_mode
                = IntervalTimer::REPEATING);
+
+    /// @brief Callback function which calls the registerd callback.
+    ///
+    /// @param error The error code retrieved from the timer.
     void callback(const boost::system::error_code& error);
+
+    /// @brief Cancel timer.
     void cancel() {
+        lock_guard<mutex> lk (mutex_);
         timer_.cancel();
         interval_ = 0;
     }
+
+    /// @brief Get the timer interval.
+    ///
+    /// @return The timer interval.
     long getInterval() const { return (interval_); }
+
 private:
-    // a function to update timer_ when it expires
+
+    /// @brief Update function to update timer_ when it expires.
+    ///
+    /// Should be called in a thread safe context.
     void update();
-    // a function to call back when timer_ expires
+
+    /// @brief The callback function to call when timer_ expires.
     IntervalTimer::Callback cbfunc_;
-    // interval in milliseconds
+
+    /// @brief The interval in milliseconds.
     std::atomic<long> interval_;
-    // asio timer
+
+    /// @brief The asio timer.
     boost::asio::deadline_timer timer_;
 
-    // Controls how the timer behaves after expiration.
+    /// @brief Controls how the timer behaves after expiration.
     IntervalTimer::Mode mode_;
 
-    // interval_ will be set to this value in destructor in order to detect
-    // use-after-free type of bugs.
+    /// @brief Mutex to protect the internal state.
+    std::mutex mutex_;
+
+    /// @brief Invalid interval value.
+    ///
+    /// @ref interval_ will be set to this value in destructor in order to
+    /// detect use-after-free type of bugs.
     static const long INVALIDATED_INTERVAL = -1;
 };
 
 IntervalTimerImpl::IntervalTimerImpl(IOService& io_service) :
     interval_(0), timer_(io_service.get_io_service()),
-    mode_(IntervalTimer::REPEATING)
-{}
+    mode_(IntervalTimer::REPEATING) {
+}
 
 IntervalTimerImpl::~IntervalTimerImpl() {
     interval_ = INVALIDATED_INTERVAL;
@@ -76,8 +111,7 @@ IntervalTimerImpl::~IntervalTimerImpl() {
 void
 IntervalTimerImpl::setup(const IntervalTimer::Callback& cbfunc,
                          const long interval,
-                         const IntervalTimer::Mode& mode)
-{
+                         const IntervalTimer::Mode& mode) {
     // Interval should not be less than 0.
     if (interval < 0) {
         isc_throw(isc::BadValue, "Interval should not be less than or "
@@ -87,6 +121,8 @@ IntervalTimerImpl::setup(const IntervalTimer::Callback& cbfunc,
     if (cbfunc.empty()) {
         isc_throw(isc::InvalidParameter, "Callback function is empty");
     }
+
+    lock_guard<mutex> lk(mutex_);
     cbfunc_ = cbfunc;
     interval_ = interval;
     mode_ = mode;
@@ -111,19 +147,23 @@ IntervalTimerImpl::update() {
         isc_throw(isc::Unexpected, "Failed to update timer: " << e.what());
     } catch (const boost::bad_weak_ptr&) {
         // Can't happen. It means a severe internal bug.
-        assert(0);
     }
 }
 
 void
 IntervalTimerImpl::callback(const boost::system::error_code& ec) {
-    assert(interval_ != INVALIDATED_INTERVAL);
+    if (interval_ == INVALIDATED_INTERVAL) {
+        isc_throw(isc::BadValue, "Interval internal state");
+    }
     if (interval_ == 0 || ec) {
         // timer has been canceled. Do nothing.
     } else {
-        // If we should repeat, set next expire time.
-        if (mode_ == IntervalTimer::REPEATING) {
-            update();
+        {
+            lock_guard<mutex> lk(mutex_);
+            // If we should repeat, set next expire time.
+            if (mode_ == IntervalTimer::REPEATING) {
+                update();
+            }
         }
 
         // Invoke the call back function.
@@ -132,8 +172,8 @@ IntervalTimerImpl::callback(const boost::system::error_code& ec) {
 }
 
 IntervalTimer::IntervalTimer(IOService& io_service) :
-    impl_(new IntervalTimerImpl(io_service))
-{}
+    impl_(new IntervalTimerImpl(io_service)) {
+}
 
 IntervalTimer::~IntervalTimer() {
     // Cancel the timer to make sure cbfunc_() will not be called any more.
index bde27af3c2f64d65cfd9767850c71bb57740809c..a012175ea1b3dbe27edd4628e38ceb3315f20362 100644 (file)
@@ -96,19 +96,20 @@ typedef boost::shared_ptr<ConnectionPool> ConnectionPoolPtr;
 /// the new request is stored in the FIFO queue. The queued requests to the
 /// particular URL are sent to the server when the current transaction ends.
 ///
-/// The communication over the TCP socket is asynchronous. The caller is notified
-/// about the completion of the transaction via a callback that the caller supplies
-/// when initiating the transaction.
+/// The communication over the TCP socket is asynchronous. The caller is
+/// notified about the completion of the transaction via a callback that the
+/// caller supplies when initiating the transaction.
 class Connection : public boost::enable_shared_from_this<Connection> {
 public:
 
     /// @brief Constructor.
     ///
     /// @param io_service IO service to be used for the connection.
-    /// @param conn_pool Back pointer to the connection pool to which this connection
-    /// belongs.
+    /// @param conn_pool Back pointer to the connection pool to which this
+    /// connection belongs.
     /// @param url URL associated with this connection.
-    explicit Connection(IOService& io_service, const ConnectionPoolPtr& conn_pool,
+    explicit Connection(IOService& io_service,
+                        const ConnectionPoolPtr& conn_pool,
                         const Url& url);
 
     /// @brief Destructor.
@@ -119,19 +120,21 @@ public:
     /// This method expects that all pointers provided as argument are non-null.
     ///
     /// @param request Pointer to the request to be sent to the server.
-    /// @param response Pointer to the object into which the response is stored. The
-    /// caller should create a response object of the type which matches the content
-    /// type expected by the caller, e.g. HttpResponseJson when JSON content type
-    /// is expected to be received.
+    /// @param response Pointer to the object into which the response is stored.
+    /// The caller should create a response object of the type which matches the
+    /// content type expected by the caller, e.g. HttpResponseJson when JSON
+    /// content type is expected to be received.
     /// @param request_timeout Request timeout in milliseconds.
     /// @param callback Pointer to the callback function to be invoked when the
     /// transaction completes.
-    /// @param connect_callback Pointer to the callback function to be invoked when
-    /// the client connects to the server.
-    /// @param close_callback Pointer to the callback function to be invoked when
-    /// the client closes the socket to the server.
-    void doTransaction(const HttpRequestPtr& request, const HttpResponsePtr& response,
-                       const long request_timeout, const HttpClient::RequestHandler& callback,
+    /// @param connect_callback Pointer to the callback function to be invoked
+    /// when the client connects to the server.
+    /// @param close_callback Pointer to the callback function to be invoked
+    /// when the client closes the socket to the server.
+    void doTransaction(const HttpRequestPtr& request,
+                       const HttpResponsePtr& response,
+                       const long request_timeout,
+                       const HttpClient::RequestHandler& callback,
                        const HttpClient::ConnectHandler& connect_callback,
                        const HttpClient::CloseHandler& close_callback);
 
@@ -141,7 +144,7 @@ public:
     /// @brief Checks if a transaction has been initiated over this connection.
     ///
     /// @return true if transaction has been initiated, false otherwise.
-    bool isTransactionOngoing() const;
+    bool isTransactionOngoing();
 
     /// @brief Checks if a socket descriptor belongs to this connection.
     ///
@@ -169,8 +172,66 @@ public:
 
 private:
 
+    /// @brief Starts new asynchronous transaction (HTTP request and response).
+    ///
+    /// Should be called in a thread safe context.
+    ///
+    /// This method expects that all pointers provided as argument are non-null.
+    ///
+    /// @param request Pointer to the request to be sent to the server.
+    /// @param response Pointer to the object into which the response is stored.
+    /// The caller should create a response object of the type which matches the
+    /// content type expected by the caller, e.g. HttpResponseJson when JSON
+    /// content type is expected to be received.
+    /// @param request_timeout Request timeout in milliseconds.
+    /// @param callback Pointer to the callback function to be invoked when the
+    /// transaction completes.
+    /// @param connect_callback Pointer to the callback function to be invoked
+    /// when the client connects to the server.
+    /// @param close_callback Pointer to the callback function to be invoked
+    /// when the client closes the socket to the server.
+    void doTransactionInternal(const HttpRequestPtr& request,
+                               const HttpResponsePtr& response,
+                               const long request_timeout,
+                               const HttpClient::RequestHandler& callback,
+                               const HttpClient::ConnectHandler& connect_callback,
+                               const HttpClient::CloseHandler& close_callback);
+
+    /// @brief Closes the socket and cancels the request timer.
+    ///
+    /// Should be called in a thread safe context.
+    void closeInternal();
+
+    /// @brief Checks if a transaction has been initiated over this connection.
+    ///
+    /// Should be called in a thread safe context.
+    ///
+    /// @return true if transaction has been initiated, false otherwise.
+    bool isTransactionOngoingInternal() const;
+
+    /// @brief Checks and logs if premature transaction timeout is suspected.
+    ///
+    /// Should be called in a thread safe context.
+    ///
+    /// There are cases when the premature timeout occurs, e.g. as a result of
+    /// moving system clock, during the transaction. In such case, the
+    /// @c terminate function is called which resets the transaction state but
+    /// the transaction handlers may be already waiting for the execution.
+    /// Each such handler should call this function to check if the transaction
+    /// it is participating in is still alive. If it is not, it should simply
+    /// return. This method also logs such situation.
+    ///
+    /// @param transid identifier of the transaction for which the handler
+    /// is being invoked. It is compared against the current transaction
+    /// id for this connection.
+    ///
+    /// @return true if the premature timeout is suspected, false otherwise.
+    bool checkPrematureTimeoutInternal(const uint64_t transid);
+
     /// @brief Resets the state of the object.
     ///
+    /// Should be called in a thread safe context.
+    ///
     /// In particular, it removes instances of objects provided for the previous
     /// transaction by a caller. It doesn't close the socket, though.
     void resetState();
@@ -187,6 +248,38 @@ private:
     void terminate(const boost::system::error_code& ec,
                    const std::string& parsing_error = "");
 
+    /// @brief Performs tasks required after receiving a response or after an
+    /// error.
+    ///
+    /// Should be called in a thread safe context.
+    ///
+    /// This method triggers user's callback, resets the state of the connection
+    /// and initiates next transaction if there is any transaction queued for the
+    /// URL associated with this connection.
+    ///
+    /// @param ec Error code received as a result of the IO operation.
+    /// @param parsing_error Message parsing error.
+    void terminateInternal(const boost::system::error_code& ec,
+                           const std::string& parsing_error = "");
+
+    /// @brief Run parser and check if more data is needed.
+    ///
+    /// @param ec Error code received as a result of the IO operation.
+    /// @param length Number of bytes received.
+    ///
+    /// @return true if more data is needed, false otherwise.
+    bool runParser(const boost::system::error_code& ec, size_t length);
+
+    /// @brief Run parser and check if more data is needed.
+    ///
+    /// Should be called in a thread safe context.
+    ///
+    /// @param ec Error code received as a result of the IO operation.
+    /// @param length Number of bytes received.
+    ///
+    /// @return true if more data is needed, false otherwise.
+    bool runParserInternal(const boost::system::error_code& ec, size_t length);
+
     /// @brief This method schedules timer or reschedules existing timer.
     ///
     /// @param request_timeout New timer interval in milliseconds.
@@ -293,6 +386,9 @@ private:
 
     /// @brief User supplied callback.
     HttpClient::CloseHandler close_callback_;
+
+    /// @brief Mutex to protect the internal state.
+    std::mutex mutex_;
 };
 
 /// @brief Shared pointer to the connection.
@@ -709,7 +805,6 @@ Connection::closeCallback(const bool clear) {
     }
 }
 
-
 void
 Connection::doTransaction(const HttpRequestPtr& request,
                           const HttpResponsePtr& response,
@@ -717,6 +812,23 @@ Connection::doTransaction(const HttpRequestPtr& request,
                           const HttpClient::RequestHandler& callback,
                           const HttpClient::ConnectHandler& connect_callback,
                           const HttpClient::CloseHandler& close_callback) {
+    if (MultiThreadingMgr::instance().getMode()) {
+        std::lock_guard<std::mutex> lk(mutex_);
+        doTransactionInternal(request, response, request_timeout,
+                              callback, connect_callback, close_callback);
+    } else {
+        doTransactionInternal(request, response, request_timeout,
+                              callback, connect_callback, close_callback);
+    }
+}
+
+void
+Connection::doTransactionInternal(const HttpRequestPtr& request,
+                                  const HttpResponsePtr& response,
+                                  const long request_timeout,
+                                  const HttpClient::RequestHandler& callback,
+                                  const HttpClient::ConnectHandler& connect_callback,
+                                  const HttpClient::CloseHandler& close_callback) {
     try {
         current_request_ = request;
         current_response_ = response;
@@ -774,16 +886,37 @@ Connection::doTransaction(const HttpRequestPtr& request,
 
 void
 Connection::close() {
+    if (MultiThreadingMgr::instance().getMode()) {
+        std::lock_guard<std::mutex> lk(mutex_);
+        return (closeInternal());
+    } else {
+        return (closeInternal());
+    }
+}
+
+void
+Connection::closeInternal() {
     // Pass in true to discard the callback.
     closeCallback(true);
 
     timer_.cancel();
     socket_.close();
+
     resetState();
 }
 
 bool
-Connection::isTransactionOngoing() const {
+Connection::isTransactionOngoing() {
+    if (MultiThreadingMgr::instance().getMode()) {
+        std::lock_guard<std::mutex> lk(mutex_);
+        return (isTransactionOngoingInternal());
+    } else {
+        return (isTransactionOngoingInternal());
+    }
+}
+
+bool
+Connection::isTransactionOngoingInternal() const {
     return (static_cast<bool>(current_request_));
 }
 
@@ -794,12 +927,22 @@ Connection::isMySocket(int socket_fd) const {
 
 bool
 Connection::checkPrematureTimeout(const uint64_t transid) {
+    if (MultiThreadingMgr::instance().getMode()) {
+        std::lock_guard<std::mutex> lk(mutex_);
+        return (checkPrematureTimeoutInternal(transid));
+    } else {
+        return (checkPrematureTimeoutInternal(transid));
+    }
+}
+
+bool
+Connection::checkPrematureTimeoutInternal(const uint64_t transid) {
     // If there is no transaction but the handlers are invoked it means
     // that the last transaction in the queue timed out prematurely.
     // Also, if there is a transaction in progress but the ID of that
     // transaction doesn't match the one associated with the handler it,
     // also means that the transaction timed out prematurely.
-    if (!isTransactionOngoing() || (transid != current_transid_)) {
+    if (!isTransactionOngoingInternal() || (transid != current_transid_)) {
         LOG_WARN(http_logger, HTTP_PREMATURE_CONNECTION_TIMEOUT_OCCURRED);
         return (true);
     }
@@ -809,10 +952,20 @@ Connection::checkPrematureTimeout(const uint64_t transid) {
 void
 Connection::terminate(const boost::system::error_code& ec,
                       const std::string& parsing_error) {
+    if (MultiThreadingMgr::instance().getMode()) {
+        std::lock_guard<std::mutex> lk(mutex_);
+        terminateInternal(ec, parsing_error);
+    } else {
+        terminateInternal(ec, parsing_error);
+    }
+}
 
+void
+Connection::terminateInternal(const boost::system::error_code& ec,
+                              const std::string& parsing_error) {
     HttpResponsePtr response;
 
-    if (isTransactionOngoing()) {
+    if (isTransactionOngoingInternal()) {
 
         timer_.cancel();
         socket_.cancel();
@@ -827,11 +980,13 @@ Connection::terminate(const boost::system::error_code& ec,
             LOG_DEBUG(http_logger, isc::log::DBGLVL_TRACE_BASIC_DATA,
                       HTTP_SERVER_RESPONSE_RECEIVED_DETAILS)
                 .arg(url_.toText())
-                .arg((parser_ ? parser_->getBufferAsString(MAX_LOGGED_MESSAGE_SIZE)
-                      : "[HttpResponseParser is null]"));
+                .arg((parser_ ?
+                      parser_->getBufferAsString(MAX_LOGGED_MESSAGE_SIZE) :
+                      "[HttpResponseParser is null]"));
 
         } else {
-            std::string err = parsing_error.empty() ? ec.message() : parsing_error;
+            std::string err = parsing_error.empty() ? ec.message() :
+                                                      parsing_error;
 
             LOG_DEBUG(http_logger, isc::log::DBGLVL_TRACE_BASIC,
                       HTTP_BAD_SERVER_RESPONSE_RECEIVED)
@@ -844,11 +999,13 @@ Connection::terminate(const boost::system::error_code& ec,
                 LOG_DEBUG(http_logger, isc::log::DBGLVL_TRACE_BASIC_DATA,
                           HTTP_BAD_SERVER_RESPONSE_RECEIVED_DETAILS)
                     .arg(url_.toText())
-                    .arg((parser_ ? parser_->getBufferAsString()
-                          : "[HttpResponseParser is null]"));
+                    .arg((parser_ ? parser_->getBufferAsString() :
+                                    "[HttpResponseParser is null]"));
             }
         }
 
+        // unlock mutex so that callback can call any locking function.
+        mutex_.unlock();
         try {
             // The callback should take care of its own exceptions but one
             // never knows.
@@ -856,11 +1013,13 @@ Connection::terminate(const boost::system::error_code& ec,
 
         } catch (...) {
         }
+        // lock mutex so that we can continue processing.
+        mutex_.lock();
 
         // If we're not requesting connection persistence, we should close the socket.
         // We're going to reconnect for the next transaction.
         if (!current_request_->isPersistent()) {
-            close();
+            closeInternal();
         }
 
         resetState();
@@ -874,10 +1033,12 @@ Connection::terminate(const boost::system::error_code& ec,
     HttpClient::ConnectHandler connect_callback;
     HttpClient::CloseHandler close_callback;
     ConnectionPoolPtr conn_pool = conn_pool_.lock();
-    if (conn_pool && conn_pool->getNextRequest(url_, request, response, request_timeout,
-                                               callback, connect_callback, close_callback)) {
-        doTransaction(request, response, request_timeout, callback,
-                      connect_callback, close_callback);
+    if (conn_pool && conn_pool->getNextRequest(url_, request, response,
+                                               request_timeout, callback,
+                                               connect_callback,
+                                               close_callback)) {
+        doTransactionInternal(request, response, request_timeout, callback,
+                              connect_callback, close_callback);
     }
 }
 
@@ -951,7 +1112,8 @@ Connection::connectCallback(HttpClient::ConnectHandler connect_callback,
 }
 
 void
-Connection::sendCallback(const uint64_t transid, const boost::system::error_code& ec,
+Connection::sendCallback(const uint64_t transid,
+                         const boost::system::error_code& ec,
                          size_t length) {
     if (checkPrematureTimeout(transid)) {
         return;
@@ -994,7 +1156,8 @@ Connection::sendCallback(const uint64_t transid, const boost::system::error_code
 }
 
 void
-Connection::receiveCallback(const uint64_t transid, const boost::system::error_code& ec,
+Connection::receiveCallback(const uint64_t transid,
+                            const boost::system::error_code& ec,
                             size_t length) {
     if (checkPrematureTimeout(transid)) {
         return;
@@ -1021,6 +1184,24 @@ Connection::receiveCallback(const uint64_t transid, const boost::system::error_c
     // Receiving is in progress, so push back the timeout.
     scheduleTimer(timer_.getInterval());
 
+    if (runParser(ec, length)) {
+        doReceive(transid);
+    }
+}
+
+bool
+Connection::runParser(const boost::system::error_code& ec, size_t length) {
+    if (MultiThreadingMgr::instance().getMode()) {
+        std::lock_guard<std::mutex> lk(mutex_);
+        return (runParserInternal(ec, length));
+    } else {
+        return (runParserInternal(ec, length));
+    }
+}
+
+bool
+Connection::runParserInternal(const boost::system::error_code& ec,
+                              size_t length) {
     // If we have received any data, let's feed the parser with it.
     if (length != 0) {
         parser_->postBuffer(static_cast<void*>(input_buf_.data()), length);
@@ -1029,25 +1210,27 @@ Connection::receiveCallback(const uint64_t transid, const boost::system::error_c
 
     // If the parser still needs data, let's schedule another receive.
     if (parser_->needData()) {
-        doReceive(transid);
+        return (true);
 
     } else if (parser_->httpParseOk()) {
         // No more data needed and parsing has been successful so far. Let's
         // try to finalize the response parsing.
         try {
             current_response_->finalize();
-            terminate(ec);
+            terminateInternal(ec);
 
         } catch (const std::exception& ex) {
             // If there is an error here, we need to return the error message.
-            terminate(ec, ex.what());
+            terminateInternal(ec, ex.what());
         }
 
     } else {
-        // Parsing was unsuccessul. Let's pass the error message held in the
+        // Parsing was unsuccessful. Let's pass the error message held in the
         // parser.
-        terminate(ec, parser_->getErrorMessage());
+        terminateInternal(ec, parser_->getErrorMessage());
     }
+
+    return (false);
 }
 
 void