]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[#4274] Checkpoint: split Exchange
authorFrancis Dupont <fdupont@isc.org>
Fri, 23 Jan 2026 14:54:21 +0000 (15:54 +0100)
committerFrancis Dupont <fdupont@isc.org>
Mon, 9 Feb 2026 21:05:46 +0000 (22:05 +0100)
src/hooks/dhcp/radius/client_exchange.cc
src/hooks/dhcp/radius/client_exchange.h
src/hooks/dhcp/radius/tests/exchange_unittests.cc

index f9d711e06a78c7e929fc7ead40aa7a2c761340e9..4fbd8076b27c0aaffc2adec081459ea6481e0760 100644 (file)
@@ -61,12 +61,8 @@ Exchange::Exchange(const asiolink::IOServicePtr io_service,
                    const Servers& servers,
                    Handler handler)
     : identifier_(""), io_service_(io_service), sync_(false),
-      started_(false), terminated_(false), rc_(ERROR_RC),
-      start_time_(std::chrono::steady_clock().now()),
-      socket_(), ep_(), timer_(), server_(), idx_(0),
-      request_(request), sent_(), received_(), buffer_(), size_(0),
-      retries_(0), maxretries_(maxretries), servers_(servers),
-      postponed_(), handler_(handler), mutex_(new std::mutex()) {
+      rc_(ERROR_RC), request_(request), sent_(), received_(),
+      maxretries_(maxretries), servers_(servers), handler_(handler) {
     if (!io_service) {
         isc_throw(BadValue, "null IO service");
     }
@@ -82,27 +78,12 @@ Exchange::Exchange(const asiolink::IOServicePtr io_service,
     createIdentifier();
 }
 
-ExchangePtr
-Exchange::create(const asiolink::IOServicePtr io_service,
-                 const MessagePtr& request,
-                 unsigned maxretries,
-                 const Servers& servers,
-                 Handler handler) {
-    return (ExchangePtr(new Exchange(io_service, request, maxretries, servers,
-                                     handler)));
-}
-
 Exchange::Exchange(const MessagePtr& request,
                    unsigned maxretries,
                    const Servers& servers)
     : identifier_(""), io_service_(new IOService()), sync_(true),
-      started_(false), terminated_(false), rc_(ERROR_RC),
-      start_time_(std::chrono::steady_clock().now()),
-      socket_(), ep_(), timer_(), server_(), idx_(0),
-      request_(request), sent_(), received_(), buffer_(), size_(0),
-      retries_(0), maxretries_(maxretries), servers_(servers), postponed_(),
-      handler_(), mutex_(new std::mutex()) {
-
+      rc_(ERROR_RC), request_(request), sent_(), received_(),
+      maxretries_(maxretries), servers_(servers), handler_() {
     if (!request) {
         isc_throw(BadValue, "null request");
     }
@@ -112,14 +93,48 @@ Exchange::Exchange(const MessagePtr& request,
     createIdentifier();
 }
 
+ExchangePtr
+Exchange::create(const asiolink::IOServicePtr io_service,
+                 const MessagePtr& request,
+                 unsigned maxretries,
+                 const Servers& servers,
+                 Handler handler) {
+    return (UdpExchangePtr(new UdpExchange(io_service, request, maxretries,
+                                           servers, handler)));
+}
+
 ExchangePtr
 Exchange::create(const MessagePtr& request,
-                   unsigned maxretries,
+                 unsigned maxretries,
                  const Servers& servers) {
-    return (ExchangePtr(new Exchange(request, maxretries, servers)));
+    return (UdpExchangePtr(new UdpExchange(request, maxretries, servers)));
+}
+
+UdpExchange::UdpExchange(const asiolink::IOServicePtr io_service,
+                         const MessagePtr& request,
+                         unsigned maxretries,
+                         const Servers& servers,
+                         Handler handler)
+    : Exchange(io_service, request, maxretries, servers, handler),
+      started_(false), terminated_(false),
+      start_time_(std::chrono::steady_clock().now()),
+      socket_(), ep_(), timer_(), server_(), idx_(0),
+      buffer_(), size_(0), retries_(0), postponed_(),
+      mutex_(new std::mutex()) {
 }
 
-Exchange::~Exchange() {
+UdpExchange::UdpExchange(const MessagePtr& request,
+                         unsigned maxretries,
+                         const Servers& servers)
+    : Exchange(request, maxretries, servers),
+      started_(false), terminated_(false),
+      start_time_(std::chrono::steady_clock().now()),
+      socket_(), ep_(), timer_(), server_(), idx_(0),
+      buffer_(), size_(0), retries_(0), postponed_(),
+      mutex_(new std::mutex()) {
+}
+
+UdpExchange::~UdpExchange() {
     MultiThreadingLock lock(*mutex_);
     shutdownInternal();
     timer_.reset();
@@ -165,7 +180,7 @@ Exchange::logReplyMessages() const {
 }
 
 void
-Exchange::start() {
+UdpExchange::start() {
     MultiThreadingLock lock(*mutex_);
 
     if (started_) {
@@ -192,14 +207,14 @@ Exchange::start() {
 }
 
 void
-Exchange::shutdown() {
+UdpExchange::shutdown() {
     // Avoid multiple terminations.
     MultiThreadingLock lock(*mutex_);
     shutdownInternal();
 }
 
 void
-Exchange::shutdownInternal() {
+UdpExchange::shutdownInternal() {
     if (terminated_) {
         return;
     } else {
@@ -222,8 +237,9 @@ Exchange::shutdownInternal() {
 }
 
 void
-Exchange::buildRequest() {
-    if (!server_) {
+Exchange::buildRequest(const ServerPtr& server,
+                       std::chrono::steady_clock::time_point start_time) {
+    if (!server) {
         isc_throw(Unexpected, "no server");
     }
 
@@ -242,7 +258,7 @@ Exchange::buildRequest() {
     }
 
     // Set the secret.
-    sent_->setSecret(server_->getSecret());
+    sent_->setSecret(server->getSecret());
 
     // Get attributes.
     AttributesPtr attrs = sent_->getAttributes();
@@ -254,14 +270,14 @@ Exchange::buildRequest() {
     // Add Acct-Delay-Time to Accounting-Request message.
     if ((sent_->getCode() == PW_ACCOUNTING_REQUEST) &&
         (attrs->count(PW_ACCT_DELAY_TIME) == 0)) {
-        auto delta = steady_clock().now() - start_time_;
+        auto delta = steady_clock().now() - start_time;
         seconds secs = duration_cast<seconds>(delta);
         attrs->add(Attribute::fromInt(PW_ACCT_DELAY_TIME,
                                       static_cast<uint32_t>(secs.count())));
     }
 
     // Add NAS-IP[v6]-Address with the local address.
-    IOAddress local_addr = server_->getLocalAddress();
+    IOAddress local_addr = server->getLocalAddress();
     short family = local_addr.getFamily();
     if (family == AF_INET) {
         if (attrs->count(PW_NAS_IP_ADDRESS) == 0) {
@@ -288,7 +304,12 @@ Exchange::buildRequest() {
 }
 
 void
-Exchange::open() {
+UdpExchange::buildRequest() {
+    Exchange::buildRequest(server_, start_time_);
+}
+
+void
+UdpExchange::open() {
     if (RadiusImpl::shutdown_) {
         shutdownInternal();
         return;
@@ -316,14 +337,14 @@ Exchange::open() {
                 (server_->getDeadtimeEnd() > start_time_)) {
                 postponed_.push_back(idx_);
                 ++idx_;
-                io_service_->post(std::bind(&Exchange::openNext,
+                io_service_->post(std::bind(&UdpExchange::openNext,
                                             shared_from_this()));
                 return;
             }
         } else {
             // Second pass: try postponed servers.
             if (postponed_.empty()) {
-                io_service_->post(std::bind(&Exchange::terminate,
+                io_service_->post(std::bind(&UdpExchange::terminate,
                                             shared_from_this()));
                 return;
             }
@@ -382,7 +403,7 @@ Exchange::open() {
                 .arg(ep_->getPort());
 
             socket_->asyncSend(&buffer_[0], buffer_.size(), ep_.get(),
-                               std::bind(&Exchange::sentHandler,
+                               std::bind(&UdpExchange::sentHandler,
                                          shared_from_this(),
                                          ph::_1,   // error_code.
                                          ph::_2)); // size.
@@ -397,7 +418,7 @@ Exchange::open() {
                 socket_->close();
                 socket_.reset();
             }
-            io_service_->post(std::bind(&Exchange::openNext,
+            io_service_->post(std::bind(&UdpExchange::openNext,
                                         shared_from_this()));
             return;
         }
@@ -419,7 +440,7 @@ Exchange::open() {
         if (idx_ == servers_.size()) {
             // Postponed servers are exhausted.
             if (postponed_.size() < 2) {
-                io_service_->post(std::bind(&Exchange::terminate,
+                io_service_->post(std::bind(&UdpExchange::terminate,
                                             shared_from_this()));
                 return;
             }
@@ -429,13 +450,13 @@ Exchange::open() {
             // Try next server.
             ++idx_;
             if ((idx_ == servers_.size()) && (postponed_.empty())) {
-                io_service_->post(std::bind(&Exchange::terminate,
+                io_service_->post(std::bind(&UdpExchange::terminate,
                                             shared_from_this()));
                 return;
             }
         }
         // Call again open to try the next server.
-        io_service_->post(std::bind(&Exchange::openNext,
+        io_service_->post(std::bind(&UdpExchange::openNext,
                                     shared_from_this()));
         return;
     }
@@ -476,12 +497,12 @@ Exchange::open() {
             .arg(retries_);
 
         socket_->asyncSend(&buffer_[0],
-                                buffer_.size(),
-                                ep_.get(),
-                                std::bind(&Exchange::sentHandler,
-                                          shared_from_this(),
-                                          ph::_1,   // error_code.
-                                          ph::_2)); // size.
+                           buffer_.size(),
+                           ep_.get(),
+                           std::bind(&UdpExchange::sentHandler,
+                                     shared_from_this(),
+                                     ph::_1,   // error_code.
+                                     ph::_2)); // size.
         return;
     } catch (const Exception& exc) {
         LOG_ERROR(radius_logger, RADIUS_EXCHANGE_OPEN_FAILED)
@@ -493,16 +514,16 @@ Exchange::open() {
             socket_->close();
             socket_.reset();
         }
-        io_service_->post(std::bind(&Exchange::openNext,
+        io_service_->post(std::bind(&UdpExchange::openNext,
                                     shared_from_this()));
         return;
     }
 }
 
 void
-Exchange::sentHandler(ExchangePtr ex,
-                      const boost::system::error_code ec,
-                      const size_t size) {
+UdpExchange::sentHandler(UdpExchangePtr ex,
+                         const boost::system::error_code ec,
+                         const size_t size) {
     if (!ex) {
         isc_throw(Unexpected, "null exchange in sentHandler");
     }
@@ -528,7 +549,7 @@ Exchange::sentHandler(ExchangePtr ex,
             ex->socket_->close();
             ex->socket_.reset();
         }
-        ex->io_service_->post(std::bind(&Exchange::openNext, ex));
+        ex->io_service_->post(std::bind(&UdpExchange::openNext, ex));
         return;
     }
 
@@ -540,15 +561,16 @@ Exchange::sentHandler(ExchangePtr ex,
     ex->buffer_.resize(BUF_LEN);
     ex->size_ = ex->buffer_.size();
     ex->socket_->asyncReceive(&(ex->buffer_)[0], ex->size_, 0, ex->ep_.get(),
-                              std::bind(&Exchange::receivedHandler, ex,
+                              std::bind(&UdpExchange::receivedHandler,
+                                        ex,
                                         ph::_1,   // error_code.
                                         ph::_2)); // size.
 }
 
 void
-Exchange::receivedHandler(ExchangePtr ex,
-                          const boost::system::error_code ec,
-                          const size_t size) {
+UdpExchange::receivedHandler(UdpExchangePtr ex,
+                             const boost::system::error_code ec,
+                             const size_t size) {
     if (!ex) {
         isc_throw(Unexpected, "null exchange in receivedHandler");
     }
@@ -576,7 +598,7 @@ Exchange::receivedHandler(ExchangePtr ex,
         LOG_ERROR(radius_logger, RADIUS_EXCHANGE_RECEIVE_FAILED)
             .arg(ex->identifier_)
             .arg(ec.message());
-        ex->io_service_->post(std::bind(&Exchange::openNext, ex));
+        ex->io_service_->post(std::bind(&UdpExchange::openNext, ex));
         return;
     }
 
@@ -675,15 +697,15 @@ Exchange::receivedHandler(ExchangePtr ex,
 
     // If bad then retry, if not including reject it is done.
     if ((ex->rc_ != OK_RC) && (ex->rc_ != REJECT_RC)) {
-        ex->io_service_->post(std::bind(&Exchange::openNext, ex));
+        ex->io_service_->post(std::bind(&UdpExchange::openNext, ex));
     } else {
         ex->logReplyMessages();
-        ex->io_service_->post(std::bind(&Exchange::terminate, ex));
+        ex->io_service_->post(std::bind(&UdpExchange::terminate, ex));
     }
 }
 
 void
-Exchange::terminate() {
+UdpExchange::terminate() {
     // Avoid multiple terminations.
     MultiThreadingLock lock(*mutex_);
 
@@ -733,15 +755,15 @@ Exchange::terminate() {
 }
 
 void
-Exchange::setTimer() {
+UdpExchange::setTimer() {
     cancelTimer();
     timer_.reset(new IntervalTimer(io_service_));
-    timer_->setup(std::bind(&Exchange::timeoutHandler, shared_from_this()),
+    timer_->setup(std::bind(&UdpExchange::timeoutHandler, shared_from_this()),
                   server_->getTimeout() * 1000, IntervalTimer::ONE_SHOT);
 }
 
 void
-Exchange::cancelTimer() {
+UdpExchange::cancelTimer() {
     if (timer_) {
         timer_->cancel();
         timer_.reset();
@@ -749,7 +771,7 @@ Exchange::cancelTimer() {
 }
 
 void
-Exchange::timeoutHandler(ExchangePtr ex) {
+UdpExchange::timeoutHandler(UdpExchangePtr ex) {
     MultiThreadingLock lock(*ex->mutex_);
     LOG_ERROR(radius_logger, RADIUS_EXCHANGE_TIMEOUT)
         .arg(ex->identifier_);
index 0c1d71173d70c21b38190a8dc88dfbd57f823c64..4b7b2ab2b015eb2fff1e83066d985a7bd712c443 100644 (file)
@@ -54,8 +54,8 @@ class Exchange;
 /// @brief Type of shared pointers to RADIUS exchange object.
 typedef boost::shared_ptr<Exchange> ExchangePtr;
 
-/// @brief RADIUS Exchange.
-class Exchange : public boost::enable_shared_from_this<Exchange> {
+/// @brief RADIUS Base Exchange.
+class Exchange {
 public:
     /// @brief Receive buffer size.
     static constexpr size_t BUF_LEN = 8192;
@@ -97,7 +97,7 @@ public:
                               const Servers& servers);
 
     /// @brief Destructor.
-    virtual ~Exchange();
+    virtual ~Exchange() = default;
 
     /// @brief Get identifier.
     ///
@@ -131,10 +131,10 @@ public:
     void logReplyMessages() const;
 
     /// @brief Start.
-    virtual void start();
+    virtual void start() = 0;
 
     /// @brief Shutdown.
-    virtual void shutdown();
+    virtual void shutdown() = 0;
 
 protected:
     /// @brief Constructor.
@@ -172,15 +172,91 @@ protected:
     /// @brief Sync / async flag.
     bool sync_;
 
+    /// @brief Error/return code.
+    int rc_;
+
+    /// @brief Request message.
+    MessagePtr request_;
+
+    /// @brief Sent message.
+    MessagePtr sent_;
+
+    /// @brief Received message.
+    MessagePtr received_;
+
+    /// @brief Maximum number of retries for a server.
+    /// @note 0 is a valid value which means no retry.
+    unsigned maxretries_;
+
+    /// @brief Servers (a copy which is what we need).
+    Servers servers_;
+
+    /// @brief Termination handler.
+    Handler handler_;
+
+    /// @brief Create identifier.
+    void createIdentifier();
+
+    /// @brief Build request.
+    ///
+    /// @param server Server where to send the request.
+    /// @param start_time Start time of the exchange.
+    void buildRequest(const ServerPtr& server,
+                      std::chrono::steady_clock::time_point start_time);
+};
+
+/// @brief RADIUS/UDP exchange (forward declaration).
+class UdpExchange;
+
+/// @brief Type of shared pointers to RADIUS/UDP exchange object.
+typedef boost::shared_ptr<UdpExchange> UdpExchangePtr;
+
+/// @brief RADIUS/UDP Exchange.
+class UdpExchange : public Exchange,
+                    public boost::enable_shared_from_this<UdpExchange> {
+public:
+    /// @brief Constructor.
+    ///
+    /// Async version.
+    ///
+    /// @param io_service Reference to the IO service.
+    /// @param request request message to send.
+    /// @param maxretries maximum number of retries for a server.
+    /// @param servers Servers.
+    /// @param handler Termination handler.
+    UdpExchange(const asiolink::IOServicePtr io_service,
+                const MessagePtr& request,
+                unsigned maxretries,
+                const Servers& servers,
+                Handler handler);
+
+    /// @brief Constructor.
+    ///
+    /// Sync version.
+    ///
+    /// @param request request message to send.
+    /// @param maxretries maximum number of retries for a server.
+    /// @param servers Servers.
+    UdpExchange(const MessagePtr& request,
+                unsigned maxretries,
+                const Servers& servers);
+
+    /// @brief Destructor.
+    virtual ~UdpExchange();
+
+    /// @brief Start.
+    virtual void start();
+
+    /// @brief Shutdown.
+    virtual void shutdown();
+
+protected:
     /// @brief Started flag.
     bool started_;
 
     /// @brief Terminated flag.
     bool terminated_;
 
-    /// @brief Error/return code.
-    int rc_;
-
     /// @brief Start time.
     std::chrono::steady_clock::time_point start_time_;
 
@@ -202,15 +278,6 @@ protected:
     /// or when greater than the table size the first postponed server.
     size_t idx_;
 
-    /// @brief Request message.
-    MessagePtr request_;
-
-    /// @brief Sent message.
-    MessagePtr sent_;
-
-    /// @brief Received message.
-    MessagePtr received_;
-
     /// @brief Buffer.
     std::vector<uint8_t> buffer_;
 
@@ -220,25 +287,12 @@ protected:
     /// @brief Retry counter.
     unsigned retries_;
 
-    /// @brief Maximum number of retries for a server.
-    /// @note 0 is a valid value which means no retry.
-    unsigned maxretries_;
-
-    /// @brief Servers (a copy which is what we need).
-    Servers servers_;
-
     /// @brief List of postponed server indexes.
     std::list<size_t> postponed_;
 
-    /// @brief Termination handler.
-    Handler handler_;
-
     /// @brief State change mutex.
     boost::scoped_ptr<std::mutex> mutex_;
 
-    /// @brief Create identifier.
-    void createIdentifier();
-
     /// @brief Build request.
     void buildRequest();
 
@@ -251,7 +305,7 @@ protected:
     /// @brief Class open / open next.
     ///
     /// @param ex the exchange.
-    static void openNext(ExchangePtr ex) {
+    static void openNext(UdpExchangePtr ex) {
         ex->open();
     }
 
@@ -260,7 +314,7 @@ protected:
     /// @param ex the exchange.
     /// @param ec Boost ASIO error code.
     /// @param size number of sent octets.
-    static void sentHandler(ExchangePtr ex,
+    static void sentHandler(UdpExchangePtr ex,
                             const boost::system::error_code ec,
                             const size_t size);
 
@@ -269,7 +323,7 @@ protected:
     /// @param ex the exchange.
     /// @param ec Boost ASIO error code.
     /// @param size number of received octets.
-    static void receivedHandler(ExchangePtr ex,
+    static void receivedHandler(UdpExchangePtr ex,
                                 const boost::system::error_code ec,
                                 const size_t size);
 
@@ -282,7 +336,7 @@ protected:
     /// @brief Timeout handler.
     ///
     /// @param ex the exchange.
-    static void timeoutHandler(ExchangePtr ex);
+    static void timeoutHandler(UdpExchangePtr ex);
 
     /// @brief Terminate.
     void terminate();
index b8a5702c35518cbb855b24a6d5f9bdf5abec72ff..8ea6eb1fcb9e9e40cbc319fed602e4dc509173fc 100644 (file)
@@ -127,7 +127,7 @@ TEST(TestExchange, sync) {
 }
 
 /// Test Exchange class.
-class TestExchange : public Exchange {
+class TestExchange : public UdpExchange {
 public:
     /// Constructor.
     ///
@@ -136,32 +136,32 @@ public:
                  unsigned maxretries,
                  const Servers& servers,
                  Exchange::Handler handler)
-        : Exchange(io_service, request, maxretries, servers, handler) {
+        : UdpExchange(io_service, request, maxretries, servers, handler) {
     }
 
     /// Visible members.
     using Exchange::identifier_;
     using Exchange::sync_;
-    using Exchange::started_;
-    using Exchange::terminated_;
     using Exchange::rc_;
-    using Exchange::start_time_;
-    using Exchange::socket_;
-    using Exchange::ep_;
-    using Exchange::timer_;
-    using Exchange::server_;
-    using Exchange::idx_;
-    using Exchange::sent_;
     using Exchange::received_;
-    using Exchange::buffer_;
-    using Exchange::size_;
-    using Exchange::retries_;
-    using Exchange::postponed_;
+    using UdpExchange::started_;
+    using UdpExchange::terminated_;
+    using UdpExchange::start_time_;
+    using UdpExchange::socket_;
+    using UdpExchange::ep_;
+    using UdpExchange::timer_;
+    using UdpExchange::server_;
+    using UdpExchange::idx_;
+    using UdpExchange::sent_;
+    using UdpExchange::buffer_;
+    using UdpExchange::size_;
+    using UdpExchange::retries_;
+    using UdpExchange::postponed_;
 
     /// Visible methods.
-    using Exchange::buildRequest;
-    using Exchange::open;
-    using Exchange::receivedHandler;
+    using UdpExchange::buildRequest;
+    using UdpExchange::open;
+    using UdpExchange::receivedHandler;
 };
 
 /// Type of shared pointers to test exchange objets.
@@ -798,7 +798,8 @@ TEST_F(ExchangeTest, openRetryError) {
 // Verify receivedHandler with null exchange.
 TEST_F(ExchangeTest, receivedHandlerNull) {
     auto no_error = boost::system::error_code();
-    EXPECT_THROW_MSG(TestExchange::receivedHandler(ExchangePtr(), no_error, 0), Unexpected,
+    EXPECT_THROW_MSG(TestExchange::receivedHandler(TestExchangePtr(), no_error, 0),
+                     Unexpected,
                      "null exchange in receivedHandler");
 }