]> git.ipfire.org Git - thirdparty/kea.git/commitdiff
[#3298] Made MemHostDataSource MT safe
authorFrancis Dupont <fdupont@isc.org>
Tue, 2 Apr 2024 13:27:26 +0000 (15:27 +0200)
committerFrancis Dupont <fdupont@isc.org>
Thu, 4 Apr 2024 14:50:03 +0000 (16:50 +0200)
src/lib/dhcpsrv/testutils/memory_host_data_source.cc
src/lib/dhcpsrv/testutils/memory_host_data_source.h

index b354eeba94e9523832e63b0af176d6431e7c8c83..9767fdb478e1dc1b16b335c4972624e038067fe6 100644 (file)
@@ -7,9 +7,11 @@
 #include <config.h>
 
 #include <dhcpsrv/testutils/memory_host_data_source.h>
+#include <util/multi_threading_mgr.h>
 #include <boost/foreach.hpp>
 
 using namespace isc::db;
+using namespace isc::util;
 using namespace std;
 
 namespace isc {
@@ -22,6 +24,7 @@ MemHostDataSource::getAll(const Host::IdentifierType& identifier_type,
                           const size_t identifier_len) const {
     vector<uint8_t> ident(identifier_begin, identifier_begin + identifier_len);
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // If identifier type do not match, it's not for us
         if (h->getIdentifierType() != identifier_type) {
@@ -38,6 +41,7 @@ MemHostDataSource::getAll(const Host::IdentifierType& identifier_type,
 ConstHostCollection
 MemHostDataSource::getAll4(const SubnetID& subnet_id) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Keep it when subnet_id matches.
         if (h->getIPv4SubnetID() == subnet_id) {
@@ -50,6 +54,7 @@ MemHostDataSource::getAll4(const SubnetID& subnet_id) const {
 ConstHostCollection
 MemHostDataSource::getAll6(const SubnetID& subnet_id) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Keep it when subnet_id matches.
         if (h->getIPv6SubnetID() == subnet_id) {
@@ -62,6 +67,7 @@ MemHostDataSource::getAll6(const SubnetID& subnet_id) const {
 ConstHostCollection
 MemHostDataSource::getAllbyHostname(const std::string& hostname) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Keep it when hostname matches.
         if (h->getLowerHostname() == hostname) {
@@ -75,6 +81,7 @@ ConstHostCollection
 MemHostDataSource::getAllbyHostname4(const std::string& hostname,
                                      const SubnetID& subnet_id) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Keep it when hostname and subnet_id match.
         if ((h->getLowerHostname() == hostname) &&
@@ -89,6 +96,7 @@ ConstHostCollection
 MemHostDataSource::getAllbyHostname6(const std::string& hostname,
                                      const SubnetID& subnet_id) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Keep it when hostname and subnet_id match.
         if ((h->getLowerHostname() == hostname) &&
@@ -105,6 +113,7 @@ MemHostDataSource::getPage4(const SubnetID& subnet_id,
                             uint64_t lower_host_id,
                             const HostPageSize& page_size) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Skip it when subnet_id does not match.
         if (h->getIPv4SubnetID() != subnet_id) {
@@ -127,6 +136,7 @@ MemHostDataSource::getPage6(const SubnetID& subnet_id,
                             uint64_t lower_host_id,
                             const HostPageSize& page_size) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // Skip it when subnet_id does not match.
         if (h->getIPv6SubnetID() != subnet_id) {
@@ -148,6 +158,7 @@ MemHostDataSource::getPage4(size_t& /*source_index*/,
                             uint64_t lower_host_id,
                             const HostPageSize& page_size) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         if (lower_host_id && (h->getHostId() <= lower_host_id)) {
             continue;
@@ -165,6 +176,7 @@ MemHostDataSource::getPage6(size_t& /*source_index*/,
                             uint64_t lower_host_id,
                             const HostPageSize& page_size) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         if (lower_host_id && (h->getHostId() <= lower_host_id)) {
             continue;
@@ -180,6 +192,7 @@ MemHostDataSource::getPage6(size_t& /*source_index*/,
 ConstHostCollection
 MemHostDataSource::getAll4(const asiolink::IOAddress& address) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         if (h->getIPv4Reservation() == address) {
             hosts.push_back(h);
@@ -195,6 +208,7 @@ MemHostDataSource::get4(const SubnetID& subnet_id,
                         const uint8_t* identifier_begin,
                         const size_t identifier_len) const {
     vector<uint8_t> ident(identifier_begin, identifier_begin + identifier_len);
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // If either subnet-id or identifier type do not match,
         // it's not our host
@@ -218,6 +232,7 @@ MemHostDataSource::get6(const SubnetID& subnet_id,
                         const uint8_t* identifier_begin,
                         const size_t identifier_len) const {
     vector<uint8_t> ident(identifier_begin, identifier_begin + identifier_len);
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         // If either subnet-id or identifier type do not match,
         // it's not our host
@@ -237,6 +252,7 @@ MemHostDataSource::get6(const SubnetID& subnet_id,
 ConstHostPtr
 MemHostDataSource::get4(const SubnetID& subnet_id,
                         const asiolink::IOAddress& address) const {
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         if (h->getIPv4SubnetID() == subnet_id &&
             h->getIPv4Reservation() == address) {
@@ -251,6 +267,7 @@ ConstHostCollection
 MemHostDataSource::getAll4(const SubnetID& subnet_id,
                            const asiolink::IOAddress& address) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         if (h->getIPv4SubnetID() == subnet_id &&
             h->getIPv4Reservation() == address) {
@@ -270,6 +287,7 @@ MemHostDataSource::get6(const asiolink::IOAddress& /*prefix*/,
 ConstHostPtr
 MemHostDataSource::get6(const SubnetID& subnet_id,
                         const asiolink::IOAddress& address) const {
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
 
         // Naive approach: check hosts one by one
@@ -297,6 +315,7 @@ ConstHostCollection
 MemHostDataSource::getAll6(const SubnetID& subnet_id,
                            const asiolink::IOAddress& address) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         if (h->getIPv6SubnetID() != subnet_id) {
             continue;
@@ -316,6 +335,7 @@ MemHostDataSource::getAll6(const SubnetID& subnet_id,
 ConstHostCollection
 MemHostDataSource::getAll6(const asiolink::IOAddress& address) const {
     ConstHostCollection hosts;
+    MultiThreadingLock lock(mutex_);
     for (auto const& h : store_) {
         auto const& resrvs = h->getIPv6Reservations();
         BOOST_FOREACH(auto const& r, resrvs) {
@@ -330,6 +350,7 @@ MemHostDataSource::getAll6(const asiolink::IOAddress& address) const {
 
 void
 MemHostDataSource::add(const HostPtr& host) {
+    MultiThreadingLock lock(mutex_);
     host->setHostId(++next_host_id_);
     store_.push_back(host);
 }
@@ -337,6 +358,7 @@ MemHostDataSource::add(const HostPtr& host) {
 bool
 MemHostDataSource::del(const SubnetID& subnet_id,
                        const asiolink::IOAddress& addr) {
+    MultiThreadingLock lock(mutex_);
     for (auto h = store_.begin(); h != store_.end(); ++h) {
         if (addr.isV4()) {
             if ((*h)->getIPv4SubnetID() == subnet_id &&
@@ -371,6 +393,7 @@ MemHostDataSource::del4(const SubnetID& subnet_id,
                         const uint8_t* identifier_begin,
                         const size_t identifier_len) {
     vector<uint8_t> ident(identifier_begin, identifier_begin + identifier_len);
+    MultiThreadingLock lock(mutex_);
     for (auto h = store_.begin(); h != store_.end(); ++h) {
         // If either subnet-id or identifier type do not match,
         // it's not our host
@@ -394,6 +417,7 @@ MemHostDataSource::del6(const SubnetID& subnet_id,
                         const uint8_t* identifier_begin,
                         const size_t identifier_len) {
     vector<uint8_t> ident(identifier_begin, identifier_begin + identifier_len);
+    MultiThreadingLock lock(mutex_);
     for (auto h = store_.begin(); h != store_.end(); ++h) {
         // If either subnet-id or identifier type do not match,
         // it's not our host
@@ -412,6 +436,7 @@ MemHostDataSource::del6(const SubnetID& subnet_id,
 
 size_t
 MemHostDataSource::size() const {
+    MultiThreadingLock lock(mutex_);
     return (store_.size());
 }
 
index 6bca9f27d1002b1aa96ee6c1dff4e08721a01b7f..87adb889192ed181c9790a8cb479b38c0d1a3930 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2023 Internet Systems Consortium, Inc. ("ISC")
+// Copyright (C) 2018-2024 Internet Systems Consortium, Inc. ("ISC")
 //
 // This Source Code Form is subject to the terms of the Mozilla Public
 // License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -9,6 +9,7 @@
 
 #include <dhcpsrv/host_data_source_factory.h>
 #include <boost/shared_ptr.hpp>
+#include <mutex>
 #include <string>
 #include <vector>
 
@@ -333,6 +334,9 @@ protected:
 
     /// @brief Next host id
     uint64_t next_host_id_;
+
+    /// @brief Mutex to protect the store.
+    std::mutex mutable mutex_;
 };
 
 /// Pointer to the Mem host data source.