]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnscrypt.cc
Merge pull request #9070 from rgacogne/boost-173
[thirdparty/pdns.git] / pdns / dnscrypt.cc
index 4cf126ce389d6afc7f8bb979670e34c461e458e9..17eb99666fba4d7d13c309c92bf0f8bc41d1d8a3 100644 (file)
@@ -25,7 +25,6 @@
 #include "dolog.hh"
 #include "dnscrypt.hh"
 #include "dnswriter.hh"
-#include "lock.hh"
 
 DNSCryptPrivateKey::DNSCryptPrivateKey()
 {
@@ -123,17 +122,17 @@ DNSCryptQuery::~DNSCryptQuery()
 }
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
 
-DNSCryptContext::DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName)
-{
-  pthread_rwlock_init(&d_lock, 0);
 
-  loadNewCertificate(certFile, keyFile);
+DNSCryptContext::~DNSCryptContext() {
 }
 
-DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+DNSCryptContext::DNSCryptContext(const std::string& pName, const std::vector<CertKeyPaths>& certKeys): d_certKeyPaths(certKeys), providerName(pName)
 {
-  pthread_rwlock_init(&d_lock, 0);
+  reloadCertificates();
+}
 
+DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+{
   addNewCertificate(certificate, pKey);
 }
 
@@ -283,12 +282,12 @@ std::string DNSCryptContext::certificateDateToStr(uint32_t date)
   return string(buf);
 }
 
-void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active, bool reload)
+void DNSCryptContext::addNewCertificate(std::shared_ptr<DNSCryptCertificatePair>& newCert, bool reload)
 {
   WriteLock w(&d_lock);
 
-  for (auto pair : certs) {
-    if (pair->cert.getSerial() == newCert.getSerial()) {
+  for (auto pair : d_certs) {
+    if (pair->cert.getSerial() == newCert->cert.getSerial()) {
       if (reload) {
         /* on reload we just assume that this is the same certificate */
         return;
@@ -299,37 +298,56 @@ void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCr
     }
   }
 
+  d_certs.push_back(newCert);
+}
+
+void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active, bool reload)
+{
   auto pair = std::make_shared<DNSCryptCertificatePair>();
   pair->cert = newCert;
   pair->privateKey = newKey;
   computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
   pair->active = active;
-  certs.push_back(pair);
+
+  addNewCertificate(pair, reload);
 }
 
-void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active, bool reload)
+std::shared_ptr<DNSCryptCertificatePair> DNSCryptContext::loadCertificatePair(const std::string& certFile, const std::string& keyFile)
 {
-  DNSCryptCert newCert;
-  DNSCryptPrivateKey newPrivateKey;
-
-  loadCertFromFile(certFile, newCert);
-  newPrivateKey.loadFromFile(keyFile);
+  auto pair = std::make_shared<DNSCryptCertificatePair>();
+  loadCertFromFile(certFile, pair->cert);
+  pair->privateKey.loadFromFile(keyFile);
+  pair->active = true;
+  computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
+  return pair;
+}
 
-  addNewCertificate(newCert, newPrivateKey, active, reload);
-  certificatePath = certFile;
-  keyPath = keyFile;
+void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active, bool reload)
+{
+  auto newPair = DNSCryptContext::loadCertificatePair(certFile, keyFile);
+  newPair->active = active;
+  addNewCertificate(newPair, reload);
+  d_certKeyPaths.push_back({certFile, keyFile});
 }
 
-void DNSCryptContext::reloadCertificate()
+void DNSCryptContext::reloadCertificates()
 {
-  loadNewCertificate(certificatePath, keyPath, true, true);
+  std::vector<std::shared_ptr<DNSCryptCertificatePair>> newCerts;
+  for (const auto& pair : d_certKeyPaths) {
+    newCerts.push_back(DNSCryptContext::loadCertificatePair(pair.cert, pair.key));
+  }
+
+  {
+    WriteLock w(&d_lock);
+    d_certs = std::move(newCerts);
+  }
 }
 
 void DNSCryptContext::markActive(uint32_t serial)
 {
   WriteLock w(&d_lock);
 
-  for (auto pair : certs) {
+  for (auto pair : d_certs) {
     if (pair->active == false && pair->cert.getSerial() == serial) {
       pair->active = true;
       return;
@@ -342,7 +360,7 @@ void DNSCryptContext::markInactive(uint32_t serial)
 {
   WriteLock w(&d_lock);
 
-  for (auto pair : certs) {
+  for (auto pair : d_certs) {
     if (pair->active == true && pair->cert.getSerial() == serial) {
       pair->active = false;
       return;
@@ -355,9 +373,9 @@ void DNSCryptContext::removeInactiveCertificate(uint32_t serial)
 {
   WriteLock w(&d_lock);
 
-  for (auto it = certs.begin(); it != certs.end(); ) {
+  for (auto it = d_certs.begin(); it != d_certs.end(); ) {
     if ((*it)->active == false && (*it)->cert.getSerial() == serial) {
-      it = certs.erase(it);
+      it = d_certs.erase(it);
       return;
     } else {
       it++;
@@ -406,7 +424,7 @@ void DNSCryptContext::getCertificateResponse(time_t now, const DNSName& qname, u
   dh->rcode = RCode::NoError;
 
   ReadLock r(&d_lock);
-  for (const auto pair : certs) {
+  for (const auto& pair : d_certs) {
     if (!pair->active || !pair->cert.isValid(now)) {
       continue;
     }
@@ -427,7 +445,7 @@ bool DNSCryptContext::magicMatchesAPublicKey(DNSCryptQuery& query, time_t now)
   const unsigned char* magic = query.getClientMagic();
 
   ReadLock r(&d_lock);
-  for (const auto& pair : certs) {
+  for (const auto& pair : d_certs) {
     if (pair->cert.isValid(now) && memcmp(magic, pair->cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
       query.setCertificatePair(pair);
       return true;
@@ -480,7 +498,7 @@ void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, ui
 
   unsigned char nonce[DNSCRYPT_NONCE_SIZE];
   static_assert(sizeof(nonce) == (2* sizeof(d_header.clientNonce)), "Nonce should be larger than clientNonce (half)");
-  static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right");
+  static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Public key size is not right");
   static_assert(sizeof(d_pair->privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right");
 
   memcpy(nonce, &d_header.clientNonce, sizeof(d_header.clientNonce));