#include "dolog.hh"
#include "dnscrypt.hh"
#include "dnswriter.hh"
-#include "lock.hh"
DNSCryptPrivateKey::DNSCryptPrivateKey()
{
}
#endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
-DNSCryptContext::DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName)
-{
- pthread_rwlock_init(&d_lock, 0);
- loadNewCertificate(certFile, keyFile);
+DNSCryptContext::~DNSCryptContext() {
}
-DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+DNSCryptContext::DNSCryptContext(const std::string& pName, const std::vector<CertKeyPaths>& certKeys): d_certKeyPaths(certKeys), providerName(pName)
{
- pthread_rwlock_init(&d_lock, 0);
+ reloadCertificates();
+}
+DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+{
addNewCertificate(certificate, pKey);
}
return string(buf);
}
-void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active)
+void DNSCryptContext::addNewCertificate(std::shared_ptr<DNSCryptCertificatePair>& newCert, bool reload)
{
WriteLock w(&d_lock);
- for (auto pair : certs) {
- if (pair->cert.getSerial() == newCert.getSerial()) {
- throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+ for (auto pair : d_certs) {
+ if (pair->cert.getSerial() == newCert->cert.getSerial()) {
+ if (reload) {
+ /* on reload we just assume that this is the same certificate */
+ return;
+ }
+ else {
+ throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+ }
}
}
+ d_certs.push_back(newCert);
+}
+
+void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active, bool reload)
+{
auto pair = std::make_shared<DNSCryptCertificatePair>();
pair->cert = newCert;
pair->privateKey = newKey;
computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
pair->active = active;
- certs.push_back(pair);
+
+ addNewCertificate(pair, reload);
}
-void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active)
+std::shared_ptr<DNSCryptCertificatePair> DNSCryptContext::loadCertificatePair(const std::string& certFile, const std::string& keyFile)
{
- DNSCryptCert newCert;
- DNSCryptPrivateKey newPrivateKey;
+ auto pair = std::make_shared<DNSCryptCertificatePair>();
+ loadCertFromFile(certFile, pair->cert);
+ pair->privateKey.loadFromFile(keyFile);
+ pair->active = true;
+ computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
+ return pair;
+}
+
+void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active, bool reload)
+{
+ auto newPair = DNSCryptContext::loadCertificatePair(certFile, keyFile);
+ newPair->active = active;
+ addNewCertificate(newPair, reload);
+ d_certKeyPaths.push_back({certFile, keyFile});
+}
- loadCertFromFile(certFile, newCert);
- newPrivateKey.loadFromFile(keyFile);
+void DNSCryptContext::reloadCertificates()
+{
+ std::vector<std::shared_ptr<DNSCryptCertificatePair>> newCerts;
+ for (const auto& pair : d_certKeyPaths) {
+ newCerts.push_back(DNSCryptContext::loadCertificatePair(pair.cert, pair.key));
+ }
- addNewCertificate(newCert, newPrivateKey, active);
+ {
+ WriteLock w(&d_lock);
+ d_certs = std::move(newCerts);
+ }
}
void DNSCryptContext::markActive(uint32_t serial)
{
WriteLock w(&d_lock);
- for (auto pair : certs) {
+ for (auto pair : d_certs) {
if (pair->active == false && pair->cert.getSerial() == serial) {
pair->active = true;
return;
{
WriteLock w(&d_lock);
- for (auto pair : certs) {
+ for (auto pair : d_certs) {
if (pair->active == true && pair->cert.getSerial() == serial) {
pair->active = false;
return;
{
WriteLock w(&d_lock);
- for (auto it = certs.begin(); it != certs.end(); ) {
+ for (auto it = d_certs.begin(); it != d_certs.end(); ) {
if ((*it)->active == false && (*it)->cert.getSerial() == serial) {
- it = certs.erase(it);
+ it = d_certs.erase(it);
return;
} else {
it++;
dh->rcode = RCode::NoError;
ReadLock r(&d_lock);
- for (const auto pair : certs) {
+ for (const auto& pair : d_certs) {
if (!pair->active || !pair->cert.isValid(now)) {
continue;
}
const unsigned char* magic = query.getClientMagic();
ReadLock r(&d_lock);
- for (const auto& pair : certs) {
+ for (const auto& pair : d_certs) {
if (pair->cert.isValid(now) && memcmp(magic, pair->cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
query.setCertificatePair(pair);
return true;
unsigned char nonce[DNSCRYPT_NONCE_SIZE];
static_assert(sizeof(nonce) == (2* sizeof(d_header.clientNonce)), "Nonce should be larger than clientNonce (half)");
- static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right");
+ static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Public key size is not right");
static_assert(sizeof(d_pair->privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right");
memcpy(nonce, &d_header.clientNonce, sizeof(d_header.clientNonce));
return res;
}
-int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr<DNSCryptCert> cert) const
+int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr<DNSCryptCert>& cert) const
{
assert(query != nullptr);
assert(queryLen > 0);