]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnscrypt.cc
Merge pull request #9070 from rgacogne/boost-173
[thirdparty/pdns.git] / pdns / dnscrypt.cc
index 1f7456348739834b69a0f3dbd281201bcf7b6231..17eb99666fba4d7d13c309c92bf0f8bc41d1d8a3 100644 (file)
@@ -25,7 +25,6 @@
 #include "dolog.hh"
 #include "dnscrypt.hh"
 #include "dnswriter.hh"
-#include "lock.hh"
 
 DNSCryptPrivateKey::DNSCryptPrivateKey()
 {
@@ -123,17 +122,17 @@ DNSCryptQuery::~DNSCryptQuery()
 }
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
 
-DNSCryptContext::DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName)
-{
-  pthread_rwlock_init(&d_lock, 0);
 
-  loadNewCertificate(certFile, keyFile);
+DNSCryptContext::~DNSCryptContext() {
 }
 
-DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+DNSCryptContext::DNSCryptContext(const std::string& pName, const std::vector<CertKeyPaths>& certKeys): d_certKeyPaths(certKeys), providerName(pName)
 {
-  pthread_rwlock_init(&d_lock, 0);
+  reloadCertificates();
+}
 
+DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+{
   addNewCertificate(certificate, pKey);
 }
 
@@ -283,40 +282,72 @@ std::string DNSCryptContext::certificateDateToStr(uint32_t date)
   return string(buf);
 }
 
-void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active)
+void DNSCryptContext::addNewCertificate(std::shared_ptr<DNSCryptCertificatePair>& newCert, bool reload)
 {
   WriteLock w(&d_lock);
 
-  for (auto pair : certs) {
-    if (pair->cert.getSerial() == newCert.getSerial()) {
-      throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+  for (auto pair : d_certs) {
+    if (pair->cert.getSerial() == newCert->cert.getSerial()) {
+      if (reload) {
+        /* on reload we just assume that this is the same certificate */
+        return;
+      }
+      else {
+        throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+      }
     }
   }
 
+  d_certs.push_back(newCert);
+}
+
+void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active, bool reload)
+{
   auto pair = std::make_shared<DNSCryptCertificatePair>();
   pair->cert = newCert;
   pair->privateKey = newKey;
   computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
   pair->active = active;
-  certs.push_back(pair);
+
+  addNewCertificate(pair, reload);
 }
 
-void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active)
+std::shared_ptr<DNSCryptCertificatePair> DNSCryptContext::loadCertificatePair(const std::string& certFile, const std::string& keyFile)
 {
-  DNSCryptCert newCert;
-  DNSCryptPrivateKey newPrivateKey;
+  auto pair = std::make_shared<DNSCryptCertificatePair>();
+  loadCertFromFile(certFile, pair->cert);
+  pair->privateKey.loadFromFile(keyFile);
+  pair->active = true;
+  computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
+  return pair;
+}
+
+void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active, bool reload)
+{
+  auto newPair = DNSCryptContext::loadCertificatePair(certFile, keyFile);
+  newPair->active = active;
+  addNewCertificate(newPair, reload);
+  d_certKeyPaths.push_back({certFile, keyFile});
+}
 
-  loadCertFromFile(certFile, newCert);
-  newPrivateKey.loadFromFile(keyFile);
+void DNSCryptContext::reloadCertificates()
+{
+  std::vector<std::shared_ptr<DNSCryptCertificatePair>> newCerts;
+  for (const auto& pair : d_certKeyPaths) {
+    newCerts.push_back(DNSCryptContext::loadCertificatePair(pair.cert, pair.key));
+  }
 
-  addNewCertificate(newCert, newPrivateKey, active);
+  {
+    WriteLock w(&d_lock);
+    d_certs = std::move(newCerts);
+  }
 }
 
 void DNSCryptContext::markActive(uint32_t serial)
 {
   WriteLock w(&d_lock);
 
-  for (auto pair : certs) {
+  for (auto pair : d_certs) {
     if (pair->active == false && pair->cert.getSerial() == serial) {
       pair->active = true;
       return;
@@ -329,7 +360,7 @@ void DNSCryptContext::markInactive(uint32_t serial)
 {
   WriteLock w(&d_lock);
 
-  for (auto pair : certs) {
+  for (auto pair : d_certs) {
     if (pair->active == true && pair->cert.getSerial() == serial) {
       pair->active = false;
       return;
@@ -342,9 +373,9 @@ void DNSCryptContext::removeInactiveCertificate(uint32_t serial)
 {
   WriteLock w(&d_lock);
 
-  for (auto it = certs.begin(); it != certs.end(); ) {
+  for (auto it = d_certs.begin(); it != d_certs.end(); ) {
     if ((*it)->active == false && (*it)->cert.getSerial() == serial) {
-      it = certs.erase(it);
+      it = d_certs.erase(it);
       return;
     } else {
       it++;
@@ -393,7 +424,7 @@ void DNSCryptContext::getCertificateResponse(time_t now, const DNSName& qname, u
   dh->rcode = RCode::NoError;
 
   ReadLock r(&d_lock);
-  for (const auto pair : certs) {
+  for (const auto& pair : d_certs) {
     if (!pair->active || !pair->cert.isValid(now)) {
       continue;
     }
@@ -414,7 +445,7 @@ bool DNSCryptContext::magicMatchesAPublicKey(DNSCryptQuery& query, time_t now)
   const unsigned char* magic = query.getClientMagic();
 
   ReadLock r(&d_lock);
-  for (const auto& pair : certs) {
+  for (const auto& pair : d_certs) {
     if (pair->cert.isValid(now) && memcmp(magic, pair->cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
       query.setCertificatePair(pair);
       return true;
@@ -467,7 +498,7 @@ void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, ui
 
   unsigned char nonce[DNSCRYPT_NONCE_SIZE];
   static_assert(sizeof(nonce) == (2* sizeof(d_header.clientNonce)), "Nonce should be larger than clientNonce (half)");
-  static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right");
+  static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Public key size is not right");
   static_assert(sizeof(d_pair->privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right");
 
   memcpy(nonce, &d_header.clientNonce, sizeof(d_header.clientNonce));
@@ -713,7 +744,7 @@ int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_
   return res;
 }
 
-int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr<DNSCryptCert> cert) const
+int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr<DNSCryptCert>& cert) const
 {
   assert(query != nullptr);
   assert(queryLen > 0);