]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Convert DNSCrypt to SharedLockGuarded
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 30 Apr 2021 13:46:02 +0000 (15:46 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 17 Aug 2021 12:04:45 +0000 (14:04 +0200)
pdns/dnscrypt.cc
pdns/dnscrypt.hh

index d728a5c32e79a985a751b387cddf196c5debb9f8..71bfe6ced9d0255d5f31ff99e2431f8c89938644 100644 (file)
@@ -284,9 +284,9 @@ std::string DNSCryptContext::certificateDateToStr(uint32_t date)
 
 void DNSCryptContext::addNewCertificate(std::shared_ptr<DNSCryptCertificatePair>& newCert, bool reload)
 {
-  WriteLock w(&d_lock);
+  auto certs = d_certs.lock();
 
-  for (auto pair : d_certs) {
+  for (auto pair : *certs) {
     if (pair->cert.getSerial() == newCert->cert.getSerial()) {
       if (reload) {
         /* on reload we just assume that this is the same certificate */
@@ -298,7 +298,7 @@ void DNSCryptContext::addNewCertificate(std::shared_ptr<DNSCryptCertificatePair>
     }
   }
 
-  d_certs.push_back(newCert);
+  certs->push_back(newCert);
 }
 
 void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active, bool reload)
@@ -327,45 +327,33 @@ void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std:
   auto newPair = DNSCryptContext::loadCertificatePair(certFile, keyFile);
   newPair->active = active;
   addNewCertificate(newPair, reload);
-  {
-    WriteLock w(&d_lock);
-    d_certKeyPaths.push_back({certFile, keyFile});
-  }
+  d_certKeyPaths.lock()->push_back({certFile, keyFile});
 }
 
 void DNSCryptContext::reloadCertificates()
 {
   std::vector<std::shared_ptr<DNSCryptCertificatePair>> newCerts;
-
   {
-    ReadLock rl(&d_lock);
-    newCerts.reserve(d_certKeyPaths.size());
-    for (const auto& pair : d_certKeyPaths) {
+    auto paths = d_certKeyPaths.read_lock();
+    newCerts.reserve(paths->size());
+    for (const auto& pair : *paths) {
       newCerts.push_back(DNSCryptContext::loadCertificatePair(pair.cert, pair.key));
     }
   }
-
+    
   {
-    WriteLock w(&d_lock);
-    d_certs = std::move(newCerts);
+    *(d_certs.lock()) = std::move(newCerts);
   }
 }
 
 std::vector<std::shared_ptr<DNSCryptCertificatePair>> DNSCryptContext::getCertificates() {
-  std::vector<std::shared_ptr<DNSCryptCertificatePair>> ret;
-  {
-    ReadLock w(&d_lock);
-    ret = d_certs;
-  }
-
+  std::vector<std::shared_ptr<DNSCryptCertificatePair>> ret = *(d_certs.read_lock());
   return ret;
 };
 
 void DNSCryptContext::markActive(uint32_t serial)
 {
-  WriteLock w(&d_lock);
-
-  for (auto pair : d_certs) {
+  for (auto pair : *d_certs.lock()) {
     if (pair->active == false && pair->cert.getSerial() == serial) {
       pair->active = true;
       return;
@@ -376,9 +364,7 @@ void DNSCryptContext::markActive(uint32_t serial)
 
 void DNSCryptContext::markInactive(uint32_t serial)
 {
-  WriteLock w(&d_lock);
-
-  for (auto pair : d_certs) {
+  for (auto pair : *d_certs.lock()) {
     if (pair->active == true && pair->cert.getSerial() == serial) {
       pair->active = false;
       return;
@@ -389,11 +375,11 @@ void DNSCryptContext::markInactive(uint32_t serial)
 
 void DNSCryptContext::removeInactiveCertificate(uint32_t serial)
 {
-  WriteLock w(&d_lock);
+  auto certs = d_certs.lock();
 
-  for (auto it = d_certs.begin(); it != d_certs.end(); ) {
+  for (auto it = certs->begin(); it != certs->end(); ) {
     if ((*it)->active == false && (*it)->cert.getSerial() == serial) {
-      it = d_certs.erase(it);
+      it = certs->erase(it);
       return;
     } else {
       it++;
@@ -444,8 +430,8 @@ void DNSCryptContext::getCertificateResponse(time_t now, const DNSName& qname, u
   dh->qr = true;
   dh->rcode = RCode::NoError;
 
-  ReadLock r(&d_lock);
-  for (const auto& pair : d_certs) {
+  auto certs = d_certs.read_lock();
+  for (const auto& pair : *certs) {
     if (!pair->active || !pair->cert.isValid(now)) {
       continue;
     }
@@ -465,8 +451,8 @@ bool DNSCryptContext::magicMatchesAPublicKey(DNSCryptQuery& query, time_t now)
 {
   const unsigned char* magic = query.getClientMagic();
 
-  ReadLock r(&d_lock);
-  for (const auto& pair : d_certs) {
+  auto certs = d_certs.read_lock();
+  for (const auto& pair : *certs) {
     if (pair->cert.isValid(now) && memcmp(magic, pair->cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
       query.setCertificatePair(pair);
       return true;
index 0dec4c2883bec3de001fee446c18093a3c5bb426..ff4d94c4663965861c140ac4218e7253dbd191b3 100644 (file)
@@ -292,9 +292,8 @@ private:
 
   void addNewCertificate(std::shared_ptr<DNSCryptCertificatePair>& newCert, bool reload=false);
 
-  ReadWriteLock d_lock;
-  std::vector<std::shared_ptr<DNSCryptCertificatePair>> d_certs;
-  std::vector<CertKeyPaths> d_certKeyPaths;
+  SharedLockGuarded<std::vector<std::shared_ptr<DNSCryptCertificatePair>>> d_certs;
+  SharedLockGuarded<std::vector<CertKeyPaths>> d_certKeyPaths;
   DNSName providerName;
 };