From: Otto Date: Wed, 3 Feb 2021 09:04:32 +0000 (+0100) Subject: Baseline for DoT integration into sdig, taken from Habbie/sdig-dot-pin X-Git-Tag: dnsdist-1.6.0-alpha2~59^2~8 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fcff37fa8f716a03bcdcbf56fe6aaeb8f46c3dd4;p=thirdparty%2Fpdns.git Baseline for DoT integration into sdig, taken from Habbie/sdig-dot-pin --- diff --git a/configure.ac b/configure.ac index 2853015a03..54e3b9fc42 100644 --- a/configure.ac +++ b/configure.ac @@ -92,6 +92,18 @@ PDNS_CHECK_LIBCRYPTO([ PDNS_CHECK_LIBCRYPTO_ECDSA PDNS_CHECK_LIBCRYPTO_EDDSA +AM_CONDITIONAL([HAVE_GNUTLS], [false]) +AM_CONDITIONAL([HAVE_LIBSSL], [false]) + +PDNS_ENABLE_DNS_OVER_TLS +AS_IF([test "x$enable_dns_over_tls" != "xno"], [ + PDNS_WITH_LIBSSL + PDNS_WITH_GNUTLS + AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [ + AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available]) + ]) +]) + PDNS_CHECK_RAGEL([pdns/dnslabeltext.cc], [www.powerdns.com]) PDNS_CHECK_CLOCK_GETTIME diff --git a/pdns/dnsdistdist/m4/dnsdist_enable_tls.m4 b/m4/pdns_enable_tls.m4 similarity index 92% rename from pdns/dnsdistdist/m4/dnsdist_enable_tls.m4 rename to m4/pdns_enable_tls.m4 index 0a9539faff..a31591f09d 100644 --- a/pdns/dnsdistdist/m4/dnsdist_enable_tls.m4 +++ b/m4/pdns_enable_tls.m4 @@ -1,4 +1,4 @@ -AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_TLS], [ +AC_DEFUN([PDNS_ENABLE_DNS_OVER_TLS], [ AC_MSG_CHECKING([whether to enable DNS over TLS support]) AC_ARG_ENABLE([dns-over-tls], AS_HELP_STRING([--enable-dns-over-tls], [enable DNS over TLS support (requires GnuTLS or OpenSSL) @<:@default=no@:>@]), diff --git a/pdns/dnsdistdist/m4/dnsdist_with_gnutls.m4 b/m4/pdns_with_gnutls.m4 similarity index 96% rename from pdns/dnsdistdist/m4/dnsdist_with_gnutls.m4 rename to m4/pdns_with_gnutls.m4 index c1d849c156..3bfae0245c 100644 --- a/pdns/dnsdistdist/m4/dnsdist_with_gnutls.m4 +++ b/m4/pdns_with_gnutls.m4 @@ -1,4 +1,4 @@ -AC_DEFUN([DNSDIST_WITH_GNUTLS], [ +AC_DEFUN([PDNS_WITH_GNUTLS], [ AC_MSG_CHECKING([whether we will be linking in GnuTLS]) HAVE_GNUTLS=0 AC_ARG_WITH([gnutls], diff --git a/pdns/dnsdistdist/m4/dnsdist_with_libssl.m4 b/m4/pdns_with_libssl.m4 similarity index 97% rename from pdns/dnsdistdist/m4/dnsdist_with_libssl.m4 rename to m4/pdns_with_libssl.m4 index 8706e2e2c7..8a8438854c 100644 --- a/pdns/dnsdistdist/m4/dnsdist_with_libssl.m4 +++ b/m4/pdns_with_libssl.m4 @@ -1,4 +1,4 @@ -AC_DEFUN([DNSDIST_WITH_LIBSSL], [ +AC_DEFUN([PDNS_WITH_LIBSSL], [ AC_MSG_CHECKING([whether we will be linking in OpenSSL libssl]) HAVE_LIBSSL=0 AC_ARG_WITH([libssl], diff --git a/pdns/Makefile.am b/pdns/Makefile.am index fbcb5deb38..0f7fffb0e1 100644 --- a/pdns/Makefile.am +++ b/pdns/Makefile.am @@ -518,6 +518,7 @@ sdig_SOURCES = \ dnsrecords.cc \ dnswriter.cc dnswriter.hh \ ednssubnet.cc iputils.cc \ + libssl.cc libssl.hh \ logger.cc \ misc.cc misc.hh \ nsecrecords.cc \ @@ -529,8 +530,10 @@ sdig_SOURCES = \ sstuff.hh \ statbag.cc \ svc-records.cc svc-records.hh \ + tcpiohandler.cc tcpiohandler.hh \ unix_utility.cc +sdig_CPPFLAGS = $(AM_CPPFLAGS) sdig_LDADD = $(LIBCRYPTO_LIBS) sdig_LDFLAGS = $(AM_LDFLAGS) $(LIBCRYPTO_LDFLAGS) @@ -539,6 +542,25 @@ sdig_SOURCES += minicurl.cc minicurl.hh sdig_LDADD += $(LIBCURL) endif +if HAVE_DNS_OVER_TLS + +if HAVE_GNUTLS +sdig_CPPFLAGS += $(GNUTLS_CFLAGS) +sdig_LDADD += -lgnutls +endif + +if HAVE_LIBSSL +sdig_CPPFLAGS += $(LIBSSL_CFLAGS) +sdig_LDADD += $(LIBSSL_LIBS) +endif + +if LIBSODIUM +sdig_CPPFLAGS +=$(LIBSODIUM_CFLAGS) +sdig_LDADD += $(LIBSODIUM_LIBS) +endif + +endif + calidns_SOURCES = \ base32.cc \ base64.cc base64.hh \ diff --git a/pdns/dnsdistdist/configure.ac b/pdns/dnsdistdist/configure.ac index 2aad28d843..46a7b4f1a6 100644 --- a/pdns/dnsdistdist/configure.ac +++ b/pdns/dnsdistdist/configure.ac @@ -68,15 +68,15 @@ AM_CONDITIONAL([HAVE_CDB], [false]) PDNS_CHECK_LIBCRYPTO -DNSDIST_ENABLE_DNS_OVER_TLS +PDNS_ENABLE_DNS_OVER_TLS DNSDIST_ENABLE_DNS_OVER_HTTPS AS_IF([test "x$enable_dns_over_tls" != "xno" -o "x$enable_dns_over_https" != "xno"], [ - DNSDIST_WITH_LIBSSL + PDNS_WITH_LIBSSL ]) AS_IF([test "x$enable_dns_over_tls" != "xno"], [ - DNSDIST_WITH_GNUTLS + PDNS_WITH_GNUTLS AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [ AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available]) diff --git a/pdns/dnsdistdist/libssl.cc b/pdns/dnsdistdist/libssl.cc deleted file mode 100644 index deffcdbf49..0000000000 --- a/pdns/dnsdistdist/libssl.cc +++ /dev/null @@ -1,785 +0,0 @@ - -#include "config.h" -#include "libssl.hh" - -#ifdef HAVE_LIBSSL - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#ifdef HAVE_LIBSODIUM -#include -#endif /* HAVE_LIBSODIUM */ - -#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) -/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */ - -#include "lock.hh" -static std::vector openssllocks; - -extern "C" { -static void openssl_pthreads_locking_callback(int mode, int type, const char *file, int line) -{ - if (mode & CRYPTO_LOCK) { - openssllocks.at(type).lock(); - - } else { - openssllocks.at(type).unlock(); - } -} - -static unsigned long openssl_pthreads_id_callback() -{ - return (unsigned long)pthread_self(); -} -} - -static void openssl_thread_setup() -{ - openssllocks = std::vector(CRYPTO_num_locks()); - CRYPTO_set_id_callback(&openssl_pthreads_id_callback); - CRYPTO_set_locking_callback(&openssl_pthreads_locking_callback); -} - -static void openssl_thread_cleanup() -{ - CRYPTO_set_locking_callback(nullptr); - openssllocks.clear(); -} - -#endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) */ - -static std::atomic s_users; -static int s_ticketsKeyIndex{-1}; -static int s_countersIndex{-1}; -static int s_keyLogIndex{-1}; - -void registerOpenSSLUser() -{ - if (s_users.fetch_add(1) == 0) { -#ifdef HAVE_OPENSSL_INIT_CRYPTO - /* load the default configuration file (or one specified via OPENSSL_CONF), - which can then be used to load engines */ - OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CONFIG, nullptr); -#endif - -#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL)) - SSL_load_error_strings(); - OpenSSL_add_ssl_algorithms(); - openssl_thread_setup(); -#endif - s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - - if (s_ticketsKeyIndex == -1) { - throw std::runtime_error("Error getting an index for tickets key"); - } - - s_countersIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - - if (s_countersIndex == -1) { - throw std::runtime_error("Error getting an index for counters"); - } - - s_keyLogIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - - if (s_keyLogIndex == -1) { - throw std::runtime_error("Error getting an index for TLS key logging"); - } - } -} - -void unregisterOpenSSLUser() -{ - if (s_users.fetch_sub(1) == 1) { -#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL)) - ERR_free_strings(); - - EVP_cleanup(); - - CONF_modules_finish(); - CONF_modules_free(); - CONF_modules_unload(1); - - CRYPTO_cleanup_all_ex_data(); - openssl_thread_cleanup(); -#endif - } -} - -void* libssl_get_ticket_key_callback_data(SSL* s) -{ - SSL_CTX* sslCtx = SSL_get_SSL_CTX(s); - if (sslCtx == nullptr) { - return nullptr; - } - - return SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex); -} - -void libssl_set_ticket_key_callback_data(SSL_CTX* ctx, void* data) -{ - SSL_CTX_set_ex_data(ctx, s_ticketsKeyIndex, data); -} - -int libssl_ticket_key_callback(SSL *s, OpenSSLTLSTicketKeysRing& keyring, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) -{ - if (enc) { - const auto key = keyring.getEncryptionKey(); - if (key == nullptr) { - return -1; - } - - return key->encrypt(keyName, iv, ectx, hctx); - } - - bool activeEncryptionKey = false; - - const auto key = keyring.getDecryptionKey(keyName, activeEncryptionKey); - if (key == nullptr) { - /* we don't know this key, just create a new ticket */ - return 0; - } - - if (key->decrypt(iv, ectx, hctx) == false) { - return -1; - } - - if (!activeEncryptionKey) { - /* this key is not active, please encrypt the ticket content with the currently active one */ - return 2; - } - - return 1; -} - -static long libssl_server_name_callback(SSL* ssl, int* al, void* arg) -{ - (void) al; - (void) arg; - - if (SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name)) { - return SSL_TLSEXT_ERR_OK; - } - - return SSL_TLSEXT_ERR_NOACK; -} - -static void libssl_info_callback(const SSL *ssl, int where, int ret) -{ - SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl); - if (sslCtx == nullptr) { - return; - } - - TLSErrorCounters* counters = reinterpret_cast(SSL_CTX_get_ex_data(sslCtx, s_countersIndex)); - if (counters == nullptr) { - return; - } - - if (where & SSL_CB_ALERT) { - const long lastError = ERR_peek_last_error(); - switch (ERR_GET_REASON(lastError)) { -#ifdef SSL_R_DH_KEY_TOO_SMALL - case SSL_R_DH_KEY_TOO_SMALL: - ++counters->d_dhKeyTooSmall; - break; -#endif /* SSL_R_DH_KEY_TOO_SMALL */ - case SSL_R_NO_SHARED_CIPHER: - ++counters->d_noSharedCipher; - break; - case SSL_R_UNKNOWN_PROTOCOL: - ++counters->d_unknownProtocol; - break; - case SSL_R_UNSUPPORTED_PROTOCOL: -#ifdef SSL_R_VERSION_TOO_LOW - case SSL_R_VERSION_TOO_LOW: -#endif /* SSL_R_VERSION_TOO_LOW */ - ++counters->d_unsupportedProtocol; - break; - case SSL_R_INAPPROPRIATE_FALLBACK: - ++counters->d_inappropriateFallBack; - break; - case SSL_R_UNKNOWN_CIPHER_TYPE: - ++counters->d_unknownCipherType; - break; - case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE: - ++counters->d_unknownKeyExchangeType; - break; - case SSL_R_UNSUPPORTED_ELLIPTIC_CURVE: - ++counters->d_unsupportedEC; - break; - default: - break; - } - } -} - -void libssl_set_error_counters_callback(std::unique_ptr& ctx, TLSErrorCounters* counters) -{ - SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters); - SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback); -} - -int libssl_ocsp_stapling_callback(SSL* ssl, const std::map& ocspMap) -{ - auto pkey = SSL_get_privatekey(ssl); - if (pkey == nullptr) { - return SSL_TLSEXT_ERR_NOACK; - } - - /* look for an OCSP response for the corresponding private key type (RSA, ECDSA..) */ - const auto& data = ocspMap.find(EVP_PKEY_base_id(pkey)); - if (data == ocspMap.end()) { - return SSL_TLSEXT_ERR_NOACK; - } - - /* we need to allocate a copy because OpenSSL will free the pointer passed to SSL_set_tlsext_status_ocsp_resp() */ - void* copy = OPENSSL_malloc(data->second.size()); - if (copy == nullptr) { - return SSL_TLSEXT_ERR_NOACK; - } - - memcpy(copy, data->second.data(), data->second.size()); - SSL_set_tlsext_status_ocsp_resp(ssl, copy, data->second.size()); - return SSL_TLSEXT_ERR_OK; -} - -static bool libssl_validate_ocsp_response(const std::string& response) -{ - auto responsePtr = reinterpret_cast(response.data()); - std::unique_ptr resp(d2i_OCSP_RESPONSE(nullptr, &responsePtr, response.size()), OCSP_RESPONSE_free); - if (resp == nullptr) { - throw std::runtime_error("Unable to parse OCSP response"); - } - - int status = OCSP_response_status(resp.get()); - if (status != OCSP_RESPONSE_STATUS_SUCCESSFUL) { - throw std::runtime_error("OCSP response status is not successful: " + std::to_string(status)); - } - - std::unique_ptr basic(OCSP_response_get1_basic(resp.get()), OCSP_BASICRESP_free); - if (basic == nullptr) { - throw std::runtime_error("Error getting a basic OCSP response"); - } - - if (OCSP_resp_count(basic.get()) != 1) { - throw std::runtime_error("More than one single response in an OCSP basic response"); - } - - auto singleResponse = OCSP_resp_get0(basic.get(), 0); - if (singleResponse == nullptr) { - throw std::runtime_error("Error getting a single response from the basic OCSP response"); - } - - int reason; - ASN1_GENERALIZEDTIME* revTime = nullptr; - ASN1_GENERALIZEDTIME* thisUpdate = nullptr; - ASN1_GENERALIZEDTIME* nextUpdate = nullptr; - - auto singleResponseStatus = OCSP_single_get0_status(singleResponse, &reason, &revTime, &thisUpdate, &nextUpdate); - if (singleResponseStatus != V_OCSP_CERTSTATUS_GOOD) { - throw std::runtime_error("Invalid status for OCSP single response (" + std::to_string(singleResponseStatus) + ")"); - } - if (thisUpdate == nullptr || nextUpdate == nullptr) { - throw std::runtime_error("Error getting validity of OCSP single response"); - } - - auto validityResult = OCSP_check_validity(thisUpdate, nextUpdate, /* 5 minutes of leeway */ 5 * 60, -1); - if (validityResult == 0) { - throw std::runtime_error("OCSP single response is not yet, or no longer, valid"); - } - - return true; -} - -std::map libssl_load_ocsp_responses(const std::vector& ocspFiles, std::vector keyTypes) -{ - std::map ocspResponses; - - if (ocspFiles.size() > keyTypes.size()) { - throw std::runtime_error("More OCSP files than certificates and keys loaded!"); - } - - size_t count = 0; - for (const auto& filename : ocspFiles) { - std::ifstream file(filename, std::ios::binary); - std::string content; - while(file) { - char buffer[4096]; - file.read(buffer, sizeof(buffer)); - if (file.bad()) { - file.close(); - throw std::runtime_error("Unable to load OCSP response from '" + filename + "'"); - } - content.append(buffer, file.gcount()); - } - file.close(); - - try { - libssl_validate_ocsp_response(content); - ocspResponses.insert({keyTypes.at(count), std::move(content)}); - } - catch (const std::exception& e) { - throw std::runtime_error("Error checking the validity of OCSP response from '" + filename + "': " + e.what()); - } - ++count; - } - - return ocspResponses; -} - -int libssl_get_last_key_type(std::unique_ptr& ctx) -{ -#ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY - auto pkey = SSL_CTX_get0_privatekey(ctx.get()); -#else - auto temp = std::unique_ptr(SSL_new(ctx.get()), SSL_free); - if (!temp) { - return -1; - } - auto pkey = SSL_get_privatekey(temp.get()); -#endif - - if (!pkey) { - return -1; - } - - return EVP_PKEY_base_id(pkey); -} - -#ifdef HAVE_OCSP_BASIC_SIGN -bool libssl_generate_ocsp_response(const std::string& certFile, const std::string& caCert, const std::string& caKey, const std::string& outFile, int ndays, int nmin) -{ - const EVP_MD* rmd = EVP_sha256(); - - auto fp = std::unique_ptr(fopen(certFile.c_str(), "r"), fclose); - if (!fp) { - throw std::runtime_error("Unable to open '" + certFile + "' when loading the certificate to generate an OCSP response"); - } - auto cert = std::unique_ptr(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free); - - fp = std::unique_ptr(fopen(caCert.c_str(), "r"), fclose); - if (!fp) { - throw std::runtime_error("Unable to open '" + caCert + "' when loading the issuer certificate to generate an OCSP response"); - } - auto issuer = std::unique_ptr(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free); - fp = std::unique_ptr(fopen(caKey.c_str(), "r"), fclose); - if (!fp) { - throw std::runtime_error("Unable to open '" + caKey + "' when loading the issuer key to generate an OCSP response"); - } - auto issuerKey = std::unique_ptr(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr), EVP_PKEY_free); - fp.reset(); - - auto bs = std::unique_ptr(OCSP_BASICRESP_new(), OCSP_BASICRESP_free); - auto thisupd = std::unique_ptr(X509_gmtime_adj(nullptr, 0), ASN1_TIME_free); - auto nextupd = std::unique_ptr(X509_time_adj_ex(nullptr, ndays, nmin * 60, nullptr), ASN1_TIME_free); - - auto cid = std::unique_ptr(OCSP_cert_to_id(rmd, cert.get(), issuer.get()), OCSP_CERTID_free); - OCSP_basic_add1_status(bs.get(), cid.get(), V_OCSP_CERTSTATUS_GOOD, 0, nullptr, thisupd.get(), nextupd.get()); - - if (OCSP_basic_sign(bs.get(), issuer.get(), issuerKey.get(), rmd, nullptr, OCSP_NOCERTS) != 1) { - throw std::runtime_error("Error while signing the OCSP response"); - } - - auto resp = std::unique_ptr(OCSP_response_create(OCSP_RESPONSE_STATUS_SUCCESSFUL, bs.get()), OCSP_RESPONSE_free); - auto bio = std::unique_ptr(BIO_new_file(outFile.c_str(), "wb"), BIO_vfree); - if (!bio) { - throw std::runtime_error("Error opening file for writing the OCSP response"); - } - - // i2d_OCSP_RESPONSE_bio(bio.get(), resp.get()) is unusable from C++ because of an invalid cast - ASN1_i2d_bio((i2d_of_void*)i2d_OCSP_RESPONSE, bio.get(), (unsigned char*)resp.get()); - - return true; -} -#endif /* HAVE_OCSP_BASIC_SIGN */ - -LibsslTLSVersion libssl_tls_version_from_string(const std::string& str) -{ - if (str == "tls1.0") { - return LibsslTLSVersion::TLS10; - } - if (str == "tls1.1") { - return LibsslTLSVersion::TLS11; - } - if (str == "tls1.2") { - return LibsslTLSVersion::TLS12; - } - if (str == "tls1.3") { - return LibsslTLSVersion::TLS13; - } - throw std::runtime_error("Unknown TLS version '" + str); -} - -const std::string& libssl_tls_version_to_string(LibsslTLSVersion version) -{ - static const std::map versions = { - { LibsslTLSVersion::TLS10, "tls1.0" }, - { LibsslTLSVersion::TLS11, "tls1.1" }, - { LibsslTLSVersion::TLS12, "tls1.2" }, - { LibsslTLSVersion::TLS13, "tls1.3" } - }; - - const auto& it = versions.find(version); - if (it == versions.end()) { - throw std::runtime_error("Unknown TLS version (" + std::to_string((int)version) + ")"); - } - return it->second; -} - -bool libssl_set_min_tls_version(std::unique_ptr& ctx, LibsslTLSVersion version) -{ -#if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version) - /* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated - Warning: SSL_CTX_set_min_proto_version is a function-like macro in OpenSSL */ - int vers; - switch(version) { - case LibsslTLSVersion::TLS10: - vers = TLS1_VERSION; - break; - case LibsslTLSVersion::TLS11: - vers = TLS1_1_VERSION; - break; - case LibsslTLSVersion::TLS12: - vers = TLS1_2_VERSION; - break; - case LibsslTLSVersion::TLS13: -#ifdef TLS1_3_VERSION - vers = TLS1_3_VERSION; -#else - return false; -#endif /* TLS1_3_VERSION */ - break; - default: - return false; - } - - if (SSL_CTX_set_min_proto_version(ctx.get(), vers) != 1) { - return false; - } - return true; -#else - long vers = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3; - switch(version) { - case LibsslTLSVersion::TLS10: - break; - case LibsslTLSVersion::TLS11: - vers |= SSL_OP_NO_TLSv1; - break; - case LibsslTLSVersion::TLS12: - vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1; - break; - case LibsslTLSVersion::TLS13: - vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2; - break; - default: - return false; - } - - long options = SSL_CTX_get_options(ctx.get()); - SSL_CTX_set_options(ctx.get(), options | vers); - return true; -#endif -} - -OpenSSLTLSTicketKeysRing::OpenSSLTLSTicketKeysRing(size_t capacity) -{ - d_ticketKeys.set_capacity(capacity); -} - -OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() -{ -} - -void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr newKey) -{ - WriteLock wl(&d_lock); - d_ticketKeys.push_front(newKey); -} - -std::shared_ptr OpenSSLTLSTicketKeysRing::getEncryptionKey() -{ - ReadLock rl(&d_lock); - return d_ticketKeys.front(); -} - -std::shared_ptr OpenSSLTLSTicketKeysRing::getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey) -{ - ReadLock rl(&d_lock); - for (auto& key : d_ticketKeys) { - if (key->nameMatches(name)) { - activeKey = (key == d_ticketKeys.front()); - return key; - } - } - return nullptr; -} - -size_t OpenSSLTLSTicketKeysRing::getKeysCount() -{ - ReadLock rl(&d_lock); - return d_ticketKeys.size(); -} - -void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile) -{ - bool keyLoaded = false; - std::ifstream file(keyFile); - try { - do { - auto newKey = std::make_shared(file); - addKey(newKey); - keyLoaded = true; - } - while (!file.fail()); - } - catch (const std::exception& e) { - /* if we haven't been able to load at least one key, fail */ - if (!keyLoaded) { - throw; - } - } - - file.close(); -} - -void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t now) -{ - auto newKey = std::make_shared(); - addKey(newKey); -} - -OpenSSLTLSTicketKey::OpenSSLTLSTicketKey() -{ - if (RAND_bytes(d_name, sizeof(d_name)) != 1) { - throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key"); - } - - if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) { - throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key"); - } - - if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) { - throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key"); - } -#ifdef HAVE_LIBSODIUM - sodium_mlock(d_name, sizeof(d_name)); - sodium_mlock(d_cipherKey, sizeof(d_cipherKey)); - sodium_mlock(d_hmacKey, sizeof(d_hmacKey)); -#endif /* HAVE_LIBSODIUM */ -} - -OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(ifstream& file) -{ - file.read(reinterpret_cast(d_name), sizeof(d_name)); - file.read(reinterpret_cast(d_cipherKey), sizeof(d_cipherKey)); - file.read(reinterpret_cast(d_hmacKey), sizeof(d_hmacKey)); - - if (file.fail()) { - throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file"); - } -#ifdef HAVE_LIBSODIUM - sodium_mlock(d_name, sizeof(d_name)); - sodium_mlock(d_cipherKey, sizeof(d_cipherKey)); - sodium_mlock(d_hmacKey, sizeof(d_hmacKey)); -#endif /* HAVE_LIBSODIUM */ -} - -OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey() -{ -#ifdef HAVE_LIBSODIUM - sodium_munlock(d_name, sizeof(d_name)); - sodium_munlock(d_cipherKey, sizeof(d_cipherKey)); - sodium_munlock(d_hmacKey, sizeof(d_hmacKey)); -#else - OPENSSL_cleanse(d_name, sizeof(d_name)); - OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey)); - OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey)); -#endif /* HAVE_LIBSODIUM */ -} - -bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const -{ - return (memcmp(d_name, name, sizeof(d_name)) == 0); -} - -int OpenSSLTLSTicketKey::encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const -{ - memcpy(keyName, d_name, sizeof(d_name)); - - if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) { - return -1; - } - - if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) { - return -1; - } - - if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) { - return -1; - } - - return 1; -} - -bool OpenSSLTLSTicketKey::decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const -{ - if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) { - return false; - } - - if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) { - return false; - } - - return true; -} - -std::unique_ptr libssl_init_server_context(const TLSConfig& config, - std::map& ocspResponses) -{ - auto ctx = std::unique_ptr(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free); - - int sslOptions = - SSL_OP_NO_SSLv2 | - SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION | - SSL_OP_SINGLE_DH_USE | - SSL_OP_SINGLE_ECDH_USE; - - if (!config.d_enableTickets || config.d_numberOfTicketsKeys == 0) { - /* for TLS 1.3 this means no stateless tickets, but stateful tickets might still be issued, - which is something we don't want. */ - sslOptions |= SSL_OP_NO_TICKET; - /* really disable all tickets */ -#ifdef HAVE_SSL_CTX_SET_NUM_TICKETS - SSL_CTX_set_num_tickets(ctx.get(), 0); -#endif /* HAVE_SSL_CTX_SET_NUM_TICKETS */ - } - - if (config.d_sessionTimeout > 0) { - SSL_CTX_set_timeout(ctx.get(), config.d_sessionTimeout); - } - - if (config.d_preferServerCiphers) { - sslOptions |= SSL_OP_CIPHER_SERVER_PREFERENCE; -#ifdef SSL_OP_PRIORITIZE_CHACHA - sslOptions |= SSL_OP_PRIORITIZE_CHACHA; -#endif /* SSL_OP_PRIORITIZE_CHACHA */ - } - - SSL_CTX_set_options(ctx.get(), sslOptions); - if (!libssl_set_min_tls_version(ctx, config.d_minTLSVersion)) { - throw std::runtime_error("Failed to set the minimum version to '" + libssl_tls_version_to_string(config.d_minTLSVersion)); - } - -#ifdef SSL_CTX_set_ecdh_auto - SSL_CTX_set_ecdh_auto(ctx.get(), 1); -#endif - - if (config.d_maxStoredSessions == 0) { - /* disable stored sessions entirely */ - SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_OFF); - } - else { - /* use the internal built-in cache to store sessions */ - SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_SERVER); - SSL_CTX_sess_set_cache_size(ctx.get(), config.d_maxStoredSessions); - } - - /* we need to set this callback to acknowledge the server name sent by the client, - otherwise it will not stored in the session and will not be accessible when the - session is resumed, causing SSL_get_servername to return nullptr */ - SSL_CTX_set_tlsext_servername_callback(ctx.get(), &libssl_server_name_callback); - - std::vector keyTypes; - /* load certificate and private key */ - for (const auto& pair : config.d_certKeyPairs) { - if (SSL_CTX_use_certificate_chain_file(ctx.get(), pair.first.c_str()) != 1) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.first); - } - if (SSL_CTX_use_PrivateKey_file(ctx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.second); - } - if (SSL_CTX_check_private_key(ctx.get()) != 1) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("The key from '" + pair.second + "' does not match the certificate from '" + pair.first + "'"); - } - /* store the type of the new key, we might need it later to select the right OCSP stapling response */ - auto keyType = libssl_get_last_key_type(ctx); - if (keyType < 0) { - throw std::runtime_error("The key from '" + pair.second + "' has an unknown type"); - } - keyTypes.push_back(keyType); - } - - if (!config.d_ocspFiles.empty()) { - try { - ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, keyTypes); - } - catch(const std::exception& e) { - throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what())); - } - } - - if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) { - throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers); - } - -#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES - if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) { - throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13); - } -#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */ - - return ctx; -} - -#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK -static void libssl_key_log_file_callback(const SSL* ssl, const char* line) -{ - SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl); - if (sslCtx == nullptr) { - return; - } - - auto fp = reinterpret_cast(SSL_CTX_get_ex_data(sslCtx, s_keyLogIndex)); - if (fp == nullptr) { - return; - } - - fprintf(fp, "%s\n", line); - fflush(fp); -} -#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */ - -std::unique_ptr libssl_set_key_log_file(std::unique_ptr& ctx, const std::string& logFile) -{ -#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK - auto fp = std::unique_ptr(fopen(logFile.c_str(), "a"), fclose); - if (!fp) { - throw std::runtime_error("Error opening TLS log file '" + logFile + "'"); - } - - SSL_CTX_set_ex_data(ctx.get(), s_keyLogIndex, fp.get()); - SSL_CTX_set_keylog_callback(ctx.get(), &libssl_key_log_file_callback); - - return fp; -#else - return std::unique_ptr(nullptr, fclose); -#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */ -} - -#endif /* HAVE_LIBSSL */ diff --git a/pdns/dnsdistdist/libssl.cc b/pdns/dnsdistdist/libssl.cc new file mode 120000 index 0000000000..49fc2fe043 --- /dev/null +++ b/pdns/dnsdistdist/libssl.cc @@ -0,0 +1 @@ +../libssl.cc \ No newline at end of file diff --git a/pdns/dnsdistdist/m4/pdns_enable_tls.m4 b/pdns/dnsdistdist/m4/pdns_enable_tls.m4 new file mode 120000 index 0000000000..6e0eb49006 --- /dev/null +++ b/pdns/dnsdistdist/m4/pdns_enable_tls.m4 @@ -0,0 +1 @@ +../../../m4/pdns_enable_tls.m4 \ No newline at end of file diff --git a/pdns/dnsdistdist/m4/pdns_with_gnutls.m4 b/pdns/dnsdistdist/m4/pdns_with_gnutls.m4 new file mode 120000 index 0000000000..b892c7fb49 --- /dev/null +++ b/pdns/dnsdistdist/m4/pdns_with_gnutls.m4 @@ -0,0 +1 @@ +../../../m4/pdns_with_gnutls.m4 \ No newline at end of file diff --git a/pdns/dnsdistdist/m4/pdns_with_libssl.m4 b/pdns/dnsdistdist/m4/pdns_with_libssl.m4 new file mode 120000 index 0000000000..7a8a381775 --- /dev/null +++ b/pdns/dnsdistdist/m4/pdns_with_libssl.m4 @@ -0,0 +1 @@ +../../../m4/pdns_with_libssl.m4 \ No newline at end of file diff --git a/pdns/dnsdistdist/tcpiohandler.cc b/pdns/dnsdistdist/tcpiohandler.cc deleted file mode 100644 index e308e2791b..0000000000 --- a/pdns/dnsdistdist/tcpiohandler.cc +++ /dev/null @@ -1,926 +0,0 @@ - -#include "config.h" -#include "dolog.hh" -#include "iputils.hh" -#include "lock.hh" -#include "tcpiohandler.hh" - -#ifdef HAVE_LIBSODIUM -#include -#endif /* HAVE_LIBSODIUM */ - -#ifdef HAVE_DNS_OVER_TLS -#ifdef HAVE_LIBSSL -#include -#include -#include -#include - -#include "libssl.hh" - -class OpenSSLFrontendContext -{ -public: - OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys) - { - registerOpenSSLUser(); - - d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses); - if (!d_tlsCtx) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort()); - } - } - - void cleanup() - { - d_tlsCtx.reset(); - - unregisterOpenSSLUser(); - } - - OpenSSLTLSTicketKeysRing d_ticketKeys; - std::map d_ocspResponses; - std::unique_ptr d_tlsCtx{nullptr, SSL_CTX_free}; - std::unique_ptr d_keyLogFile{nullptr, fclose}; -}; - -class OpenSSLTLSConnection: public TLSConnection -{ -public: - OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr feContext): d_feContext(feContext), d_conn(std::unique_ptr(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) - { - d_socket = socket; - - if (!s_initTLSConnIndex.test_and_set()) { - /* not initialized yet */ - s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - if (s_tlsConnIndex == -1) { - throw std::runtime_error("Error getting an index for TLS connection data"); - } - } - - if (!d_conn) { - vinfolog("Error creating TLS object"); - if (g_verbose) { - ERR_print_errors_fp(stderr); - } - throw std::runtime_error("Error creating TLS object"); - } - - if (!SSL_set_fd(d_conn.get(), d_socket)) { - throw std::runtime_error("Error assigning socket"); - } - - SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this); - } - - IOState convertIORequestToIOState(int res) const - { - int error = SSL_get_error(d_conn.get(), res); - if (error == SSL_ERROR_WANT_READ) { - return IOState::NeedRead; - } - else if (error == SSL_ERROR_WANT_WRITE) { - return IOState::NeedWrite; - } - else if (error == SSL_ERROR_SYSCALL) { - throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno))); - } - else { - throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error)); - } - } - - void handleIORequest(int res, unsigned int timeout) - { - auto state = convertIORequestToIOState(res); - if (state == IOState::NeedRead) { - res = waitForData(d_socket, timeout); - if (res == 0) { - throw std::runtime_error("Timeout while reading from TLS connection"); - } - else if (res < 0) { - throw std::runtime_error("Error waiting to read from TLS connection"); - } - } - else if (state == IOState::NeedWrite) { - res = waitForRWData(d_socket, false, timeout, 0); - if (res == 0) { - throw std::runtime_error("Timeout while writing to TLS connection"); - } - else if (res < 0) { - throw std::runtime_error("Error waiting to write to TLS connection"); - } - } - } - - IOState tryHandshake() override - { - int res = SSL_accept(d_conn.get()); - if (res == 1) { - return IOState::Done; - } - else if (res < 0) { - return convertIORequestToIOState(res); - } - - throw std::runtime_error("Error accepting TLS connection"); - } - - void doHandshake() override - { - int res = 0; - do { - res = SSL_accept(d_conn.get()); - if (res < 0) { - handleIORequest(res, d_timeout); - } - } - while (res < 0); - - if (res != 1) { - throw std::runtime_error("Error accepting TLS connection"); - } - } - - IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override - { - do { - int res = SSL_write(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toWrite - pos)); - if (res <= 0) { - return convertIORequestToIOState(res); - } - else { - pos += static_cast(res); - } - } - while (pos < toWrite); - return IOState::Done; - } - - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override - { - do { - int res = SSL_read(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toRead - pos)); - if (res <= 0) { - return convertIORequestToIOState(res); - } - else { - pos += static_cast(res); - } - } - while (pos < toRead); - return IOState::Done; - } - - size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override - { - size_t got = 0; - time_t start = 0; - unsigned int remainingTime = totalTimeout; - if (totalTimeout) { - start = time(nullptr); - } - - do { - int res = SSL_read(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); - if (res <= 0) { - handleIORequest(res, readTimeout); - } - else { - got += static_cast(res); - } - - if (totalTimeout) { - time_t now = time(nullptr); - unsigned int elapsed = now - start; - if (now < start || elapsed >= remainingTime) { - throw runtime_error("Timeout while reading data"); - } - start = now; - remainingTime -= elapsed; - } - } - while (got < bufferSize); - - return got; - } - - size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override - { - size_t got = 0; - do { - int res = SSL_write(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); - if (res <= 0) { - handleIORequest(res, writeTimeout); - } - else { - got += static_cast(res); - } - } - while (got < bufferSize); - - return got; - } - - bool hasBufferedData() const override - { - if (d_conn) { - return SSL_pending(d_conn.get()) > 0; - } - - return false; - } - - void close() override - { - if (d_conn) { - SSL_shutdown(d_conn.get()); - } - } - - std::string getServerNameIndication() const override - { - if (d_conn) { - const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name); - if (value) { - return std::string(value); - } - } - return std::string(); - } - - LibsslTLSVersion getTLSVersion() const override - { - auto proto = SSL_version(d_conn.get()); - switch (proto) { - case TLS1_VERSION: - return LibsslTLSVersion::TLS10; - case TLS1_1_VERSION: - return LibsslTLSVersion::TLS11; - case TLS1_2_VERSION: - return LibsslTLSVersion::TLS12; -#ifdef TLS1_3_VERSION - case TLS1_3_VERSION: - return LibsslTLSVersion::TLS13; -#endif /* TLS1_3_VERSION */ - default: - return LibsslTLSVersion::Unknown; - } - } - - bool hasSessionBeenResumed() const override - { - if (d_conn) { - return SSL_session_reused(d_conn.get()) != 0; - } - return false; - } - - static int s_tlsConnIndex; - -private: - static std::atomic_flag s_initTLSConnIndex; - - std::shared_ptr d_feContext; - std::unique_ptr d_conn; - unsigned int d_timeout; -}; - -std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT; -int OpenSSLTLSConnection::s_tlsConnIndex = -1; - -class OpenSSLTLSIOCtx: public TLSCtx -{ -public: - OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared(fe.d_addr, fe.d_tlsConfig)) - { - d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; - - if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) { - /* use our own ticket keys handler so we can rotate them */ - SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); - libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get()); - } - - if (!d_feContext->d_ocspResponses.empty()) { - SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb); - SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses); - } - - libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters); - - if (!fe.d_tlsConfig.d_keyLogFile.empty()) { - d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile); - } - - try { - if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { - handleTicketsKeyRotation(time(nullptr)); - } - else { - OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); - } - } - catch (const std::exception& e) { - throw; - } - } - - ~OpenSSLTLSIOCtx() override - { - } - - static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) - { - OpenSSLFrontendContext* ctx = reinterpret_cast(libssl_get_ticket_key_callback_data(s)); - if (ctx == nullptr) { - return -1; - } - - int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc); - if (enc == 0) { - if (ret == 0 || ret == 2) { - OpenSSLTLSConnection* conn = reinterpret_cast(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex)); - if (conn) { - if (ret == 0) { - conn->setUnknownTicketKey(); - } - else if (ret == 2) { - conn->setResumedFromInactiveTicketKey(); - } - } - } - } - - return ret; - } - - static int ocspStaplingCb(SSL* ssl, void* arg) - { - if (ssl == nullptr || arg == nullptr) { - return SSL_TLSEXT_ERR_NOACK; - } - const auto ocspMap = reinterpret_cast*>(arg); - return libssl_ocsp_stapling_callback(ssl, *ocspMap); - } - - std::unique_ptr getConnection(int socket, unsigned int timeout, time_t now) override - { - handleTicketsKeyRotation(now); - - return std::unique_ptr(new OpenSSLTLSConnection(socket, timeout, d_feContext)); - } - - void rotateTicketsKey(time_t now) override - { - d_feContext->d_ticketKeys.rotateTicketsKey(now); - - if (d_ticketsKeyRotationDelay > 0) { - d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; - } - } - - void loadTicketsKeys(const std::string& keyFile) override final - { - d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); - - if (d_ticketsKeyRotationDelay > 0) { - d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; - } - } - - size_t getTicketsKeysCount() override - { - return d_feContext->d_ticketKeys.getKeysCount(); - } - -private: - std::shared_ptr d_feContext; -}; - -#endif /* HAVE_LIBSSL */ - -#ifdef HAVE_GNUTLS -#include -#include - -static void safe_memory_lock(void* data, size_t size) -{ -#ifdef HAVE_LIBSODIUM - sodium_mlock(data, size); -#endif -} - -static void safe_memory_release(void* data, size_t size) -{ -#ifdef HAVE_LIBSODIUM - sodium_munlock(data, size); -#elif defined(HAVE_EXPLICIT_BZERO) - explicit_bzero(data, size); -#elif defined(HAVE_EXPLICIT_MEMSET) - explicit_memset(data, 0, size); -#elif defined(HAVE_GNUTLS_MEMSET) - gnutls_memset(data, 0, size); -#else - /* shamelessly taken from Dovecot's src/lib/safe-memset.c */ - volatile unsigned int volatile_zero_idx = 0; - volatile unsigned char *p = reinterpret_cast(data); - - if (size == 0) - return; - - do { - memset(data, 0, size); - } while (p[volatile_zero_idx] != 0); -#endif -} - -class GnuTLSTicketsKey -{ -public: - GnuTLSTicketsKey() - { - if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error generating tickets key for TLS context"); - } - - safe_memory_lock(d_key.data, d_key.size); - } - - GnuTLSTicketsKey(const std::string& keyFile) - { - /* to be sure we are loading the correct amount of data, which - may change between versions, let's generate a correct key first */ - if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context"); - } - - safe_memory_lock(d_key.data, d_key.size); - - try { - ifstream file(keyFile); - file.read(reinterpret_cast(d_key.data), d_key.size); - - if (file.fail()) { - file.close(); - throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile); - } - - file.close(); - } - catch (const std::exception& e) { - safe_memory_release(d_key.data, d_key.size); - gnutls_free(d_key.data); - d_key.data = nullptr; - throw; - } - } - - ~GnuTLSTicketsKey() - { - if (d_key.data != nullptr && d_key.size > 0) { - safe_memory_release(d_key.data, d_key.size); - } - gnutls_free(d_key.data); - d_key.data = nullptr; - } - const gnutls_datum_t& getKey() const - { - return d_key; - } - -private: - gnutls_datum_t d_key{nullptr, 0}; -}; - -class GnuTLSConnection: public TLSConnection -{ -public: - - GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr& ticketsKey, bool enableTickets): d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey) - { - unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK; -#ifdef GNUTLS_NO_SIGNAL - sslOptions |= GNUTLS_NO_SIGNAL; -#endif - - d_socket = socket; - - gnutls_session_t conn; - if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error creating TLS connection"); - } - - d_conn = std::unique_ptr(conn, gnutls_deinit); - conn = nullptr; - - if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error setting certificate and key to TLS connection"); - } - - if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error setting ciphers to TLS connection"); - } - - if (enableTickets && d_ticketsKey) { - const gnutls_datum_t& key = d_ticketsKey->getKey(); - if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error setting the tickets key to TLS connection"); - } - } - - gnutls_transport_set_int(d_conn.get(), d_socket); - - /* timeouts are in milliseconds */ - gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000); - gnutls_record_set_timeout(d_conn.get(), timeout * 1000); - } - - void doHandshake() override - { - int ret = 0; - do { - ret = gnutls_handshake(d_conn.get()); - if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { - throw std::runtime_error("Error accepting a new connection"); - } - } - while (ret < 0 && ret == GNUTLS_E_INTERRUPTED); - } - - IOState tryHandshake() override - { - int ret = 0; - - do { - ret = gnutls_handshake(d_conn.get()); - if (ret == GNUTLS_E_SUCCESS) { - return IOState::Done; - } - else if (ret == GNUTLS_E_AGAIN) { - return IOState::NeedRead; - } - else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { - throw std::runtime_error("Error accepting a new connection"); - } - } while (ret == GNUTLS_E_INTERRUPTED); - - throw std::runtime_error("Error accepting a new connection"); - } - - IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override - { - do { - ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toWrite - pos); - if (res == 0) { - throw std::runtime_error("Error writing to TLS connection"); - } - else if (res > 0) { - pos += static_cast(res); - } - else if (res < 0) { - if (gnutls_error_is_fatal(res)) { - throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); - } - else if (res == GNUTLS_E_AGAIN) { - return IOState::NeedWrite; - } - warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); - } - } - while (pos < toWrite); - return IOState::Done; - } - - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override - { - do { - ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toRead - pos); - if (res == 0) { - throw std::runtime_error("Error reading from TLS connection"); - } - else if (res > 0) { - pos += static_cast(res); - } - else if (res < 0) { - if (gnutls_error_is_fatal(res)) { - throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); - } - else if (res == GNUTLS_E_AGAIN) { - return IOState::NeedRead; - } - warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); - } - } - while (pos < toRead); - return IOState::Done; - } - - size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override - { - size_t got = 0; - time_t start = 0; - unsigned int remainingTime = totalTimeout; - if (totalTimeout) { - start = time(nullptr); - } - - do { - ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); - if (res == 0) { - throw std::runtime_error("Error reading from TLS connection"); - } - else if (res > 0) { - got += static_cast(res); - } - else if (res < 0) { - if (gnutls_error_is_fatal(res)) { - throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); - } - else if (res == GNUTLS_E_AGAIN) { - int result = waitForData(d_socket, readTimeout); - if (result <= 0) { - throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result)); - } - } - else { - vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res)); - } - } - - if (totalTimeout) { - time_t now = time(nullptr); - unsigned int elapsed = now - start; - if (now < start || elapsed >= remainingTime) { - throw runtime_error("Timeout while reading data"); - } - start = now; - remainingTime -= elapsed; - } - } - while (got < bufferSize); - - return got; - } - - size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override - { - size_t got = 0; - - do { - ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); - if (res == 0) { - throw std::runtime_error("Error writing to TLS connection"); - } - else if (res > 0) { - got += static_cast(res); - } - else if (res < 0) { - if (gnutls_error_is_fatal(res)) { - throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); - } - else if (res == GNUTLS_E_AGAIN) { - int result = waitForRWData(d_socket, false, writeTimeout, 0); - if (result <= 0) { - throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result)); - } - } - else { - vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); - } - } - } - while (got < bufferSize); - - return got; - } - - bool hasBufferedData() const override - { - if (d_conn) { - return gnutls_record_check_pending(d_conn.get()) > 0; - } - - return false; - } - - std::string getServerNameIndication() const override - { - if (d_conn) { - unsigned int type; - size_t name_len = 256; - std::string sni; - sni.resize(name_len); - - int res = gnutls_server_name_get(d_conn.get(), const_cast(sni.c_str()), &name_len, &type, 0); - if (res == GNUTLS_E_SUCCESS) { - sni.resize(name_len); - return sni; - } - } - return std::string(); - } - - LibsslTLSVersion getTLSVersion() const override - { - auto proto = gnutls_protocol_get_version(d_conn.get()); - switch (proto) { - case GNUTLS_TLS1_0: - return LibsslTLSVersion::TLS10; - case GNUTLS_TLS1_1: - return LibsslTLSVersion::TLS11; - case GNUTLS_TLS1_2: - return LibsslTLSVersion::TLS12; -#if GNUTLS_VERSION_NUMBER >= 0x030603 - case GNUTLS_TLS1_3: - return LibsslTLSVersion::TLS13; -#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */ - default: - return LibsslTLSVersion::Unknown; - } - } - - bool hasSessionBeenResumed() const override - { - if (d_conn) { - return gnutls_session_is_resumed(d_conn.get()) != 0; - } - return false; - } - - void close() override - { - if (d_conn) { - gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR); - } - } - -private: - std::unique_ptr d_conn; - std::shared_ptr d_ticketsKey; -}; - -class GnuTLSIOCtx: public TLSCtx -{ -public: - GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets) - { - int rc = 0; - d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; - - gnutls_certificate_credentials_t creds; - rc = gnutls_certificate_allocate_credentials(&creds); - if (rc != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); - } - - d_creds = std::unique_ptr(creds, gnutls_certificate_free_credentials); - creds = nullptr; - - for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) { - rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM); - if (rc != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); - } - } - - size_t count = 0; - for (const auto& file : fe.d_tlsConfig.d_ocspFiles) { - rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count); - if (rc != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); - } - ++count; - } - -#if GNUTLS_VERSION_NUMBER >= 0x030600 - rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH); - if (rc != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); - } -#endif - - rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr); - if (rc != GNUTLS_E_SUCCESS) { - throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort()); - } - - try { - if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { - handleTicketsKeyRotation(time(nullptr)); - } - else { - GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); - } - } - catch(const std::runtime_error& e) { - throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what()); - } - } - - virtual ~GnuTLSIOCtx() override - { - d_creds.reset(); - - if (d_priorityCache) { - gnutls_priority_deinit(d_priorityCache); - } - } - - std::unique_ptr getConnection(int socket, unsigned int timeout, time_t now) override - { - handleTicketsKeyRotation(now); - - std::shared_ptr ticketsKey; - { - ReadLock rl(&d_lock); - ticketsKey = d_ticketsKey; - } - - return std::unique_ptr(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets)); - } - - void rotateTicketsKey(time_t now) override - { - if (!d_enableTickets) { - return; - } - - auto newKey = std::make_shared(); - - { - WriteLock wl(&d_lock); - d_ticketsKey = newKey; - } - - if (d_ticketsKeyRotationDelay > 0) { - d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; - } - } - - void loadTicketsKeys(const std::string& file) override final - { - if (!d_enableTickets) { - return; - } - - auto newKey = std::make_shared(file); - { - WriteLock wl(&d_lock); - d_ticketsKey = newKey; - } - - if (d_ticketsKeyRotationDelay > 0) { - d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; - } - } - - size_t getTicketsKeysCount() override - { - ReadLock rl(&d_lock); - return d_ticketsKey != nullptr ? 1 : 0; - } - -private: - std::unique_ptr d_creds; - gnutls_priority_t d_priorityCache{nullptr}; - std::shared_ptr d_ticketsKey{nullptr}; - ReadWriteLock d_lock; - bool d_enableTickets{true}; -}; - -#endif /* HAVE_GNUTLS */ - -#endif /* HAVE_DNS_OVER_TLS */ - -bool TLSFrontend::setupTLS() -{ -#ifdef HAVE_DNS_OVER_TLS - /* get the "best" available provider */ - if (!d_provider.empty()) { -#ifdef HAVE_GNUTLS - if (d_provider == "gnutls") { - d_ctx = std::make_shared(*this); - return true; - } -#endif /* HAVE_GNUTLS */ -#ifdef HAVE_LIBSSL - if (d_provider == "openssl") { - d_ctx = std::make_shared(*this); - return true; - } -#endif /* HAVE_LIBSSL */ - } -#ifdef HAVE_LIBSSL - d_ctx = std::make_shared(*this); -#else /* HAVE_LIBSSL */ -#ifdef HAVE_GNUTLS - d_ctx = std::make_shared(*this); -#endif /* HAVE_GNUTLS */ -#endif /* HAVE_LIBSSL */ - -#endif /* HAVE_DNS_OVER_TLS */ - return true; -} diff --git a/pdns/dnsdistdist/tcpiohandler.cc b/pdns/dnsdistdist/tcpiohandler.cc new file mode 120000 index 0000000000..a583875518 --- /dev/null +++ b/pdns/dnsdistdist/tcpiohandler.cc @@ -0,0 +1 @@ +../tcpiohandler.cc \ No newline at end of file diff --git a/pdns/libssl.cc b/pdns/libssl.cc new file mode 100644 index 0000000000..deffcdbf49 --- /dev/null +++ b/pdns/libssl.cc @@ -0,0 +1,785 @@ + +#include "config.h" +#include "libssl.hh" + +#ifdef HAVE_LIBSSL + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifdef HAVE_LIBSODIUM +#include +#endif /* HAVE_LIBSODIUM */ + +#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) +/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */ + +#include "lock.hh" +static std::vector openssllocks; + +extern "C" { +static void openssl_pthreads_locking_callback(int mode, int type, const char *file, int line) +{ + if (mode & CRYPTO_LOCK) { + openssllocks.at(type).lock(); + + } else { + openssllocks.at(type).unlock(); + } +} + +static unsigned long openssl_pthreads_id_callback() +{ + return (unsigned long)pthread_self(); +} +} + +static void openssl_thread_setup() +{ + openssllocks = std::vector(CRYPTO_num_locks()); + CRYPTO_set_id_callback(&openssl_pthreads_id_callback); + CRYPTO_set_locking_callback(&openssl_pthreads_locking_callback); +} + +static void openssl_thread_cleanup() +{ + CRYPTO_set_locking_callback(nullptr); + openssllocks.clear(); +} + +#endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) */ + +static std::atomic s_users; +static int s_ticketsKeyIndex{-1}; +static int s_countersIndex{-1}; +static int s_keyLogIndex{-1}; + +void registerOpenSSLUser() +{ + if (s_users.fetch_add(1) == 0) { +#ifdef HAVE_OPENSSL_INIT_CRYPTO + /* load the default configuration file (or one specified via OPENSSL_CONF), + which can then be used to load engines */ + OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CONFIG, nullptr); +#endif + +#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL)) + SSL_load_error_strings(); + OpenSSL_add_ssl_algorithms(); + openssl_thread_setup(); +#endif + s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + + if (s_ticketsKeyIndex == -1) { + throw std::runtime_error("Error getting an index for tickets key"); + } + + s_countersIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + + if (s_countersIndex == -1) { + throw std::runtime_error("Error getting an index for counters"); + } + + s_keyLogIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + + if (s_keyLogIndex == -1) { + throw std::runtime_error("Error getting an index for TLS key logging"); + } + } +} + +void unregisterOpenSSLUser() +{ + if (s_users.fetch_sub(1) == 1) { +#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL)) + ERR_free_strings(); + + EVP_cleanup(); + + CONF_modules_finish(); + CONF_modules_free(); + CONF_modules_unload(1); + + CRYPTO_cleanup_all_ex_data(); + openssl_thread_cleanup(); +#endif + } +} + +void* libssl_get_ticket_key_callback_data(SSL* s) +{ + SSL_CTX* sslCtx = SSL_get_SSL_CTX(s); + if (sslCtx == nullptr) { + return nullptr; + } + + return SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex); +} + +void libssl_set_ticket_key_callback_data(SSL_CTX* ctx, void* data) +{ + SSL_CTX_set_ex_data(ctx, s_ticketsKeyIndex, data); +} + +int libssl_ticket_key_callback(SSL *s, OpenSSLTLSTicketKeysRing& keyring, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) +{ + if (enc) { + const auto key = keyring.getEncryptionKey(); + if (key == nullptr) { + return -1; + } + + return key->encrypt(keyName, iv, ectx, hctx); + } + + bool activeEncryptionKey = false; + + const auto key = keyring.getDecryptionKey(keyName, activeEncryptionKey); + if (key == nullptr) { + /* we don't know this key, just create a new ticket */ + return 0; + } + + if (key->decrypt(iv, ectx, hctx) == false) { + return -1; + } + + if (!activeEncryptionKey) { + /* this key is not active, please encrypt the ticket content with the currently active one */ + return 2; + } + + return 1; +} + +static long libssl_server_name_callback(SSL* ssl, int* al, void* arg) +{ + (void) al; + (void) arg; + + if (SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name)) { + return SSL_TLSEXT_ERR_OK; + } + + return SSL_TLSEXT_ERR_NOACK; +} + +static void libssl_info_callback(const SSL *ssl, int where, int ret) +{ + SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl); + if (sslCtx == nullptr) { + return; + } + + TLSErrorCounters* counters = reinterpret_cast(SSL_CTX_get_ex_data(sslCtx, s_countersIndex)); + if (counters == nullptr) { + return; + } + + if (where & SSL_CB_ALERT) { + const long lastError = ERR_peek_last_error(); + switch (ERR_GET_REASON(lastError)) { +#ifdef SSL_R_DH_KEY_TOO_SMALL + case SSL_R_DH_KEY_TOO_SMALL: + ++counters->d_dhKeyTooSmall; + break; +#endif /* SSL_R_DH_KEY_TOO_SMALL */ + case SSL_R_NO_SHARED_CIPHER: + ++counters->d_noSharedCipher; + break; + case SSL_R_UNKNOWN_PROTOCOL: + ++counters->d_unknownProtocol; + break; + case SSL_R_UNSUPPORTED_PROTOCOL: +#ifdef SSL_R_VERSION_TOO_LOW + case SSL_R_VERSION_TOO_LOW: +#endif /* SSL_R_VERSION_TOO_LOW */ + ++counters->d_unsupportedProtocol; + break; + case SSL_R_INAPPROPRIATE_FALLBACK: + ++counters->d_inappropriateFallBack; + break; + case SSL_R_UNKNOWN_CIPHER_TYPE: + ++counters->d_unknownCipherType; + break; + case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE: + ++counters->d_unknownKeyExchangeType; + break; + case SSL_R_UNSUPPORTED_ELLIPTIC_CURVE: + ++counters->d_unsupportedEC; + break; + default: + break; + } + } +} + +void libssl_set_error_counters_callback(std::unique_ptr& ctx, TLSErrorCounters* counters) +{ + SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters); + SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback); +} + +int libssl_ocsp_stapling_callback(SSL* ssl, const std::map& ocspMap) +{ + auto pkey = SSL_get_privatekey(ssl); + if (pkey == nullptr) { + return SSL_TLSEXT_ERR_NOACK; + } + + /* look for an OCSP response for the corresponding private key type (RSA, ECDSA..) */ + const auto& data = ocspMap.find(EVP_PKEY_base_id(pkey)); + if (data == ocspMap.end()) { + return SSL_TLSEXT_ERR_NOACK; + } + + /* we need to allocate a copy because OpenSSL will free the pointer passed to SSL_set_tlsext_status_ocsp_resp() */ + void* copy = OPENSSL_malloc(data->second.size()); + if (copy == nullptr) { + return SSL_TLSEXT_ERR_NOACK; + } + + memcpy(copy, data->second.data(), data->second.size()); + SSL_set_tlsext_status_ocsp_resp(ssl, copy, data->second.size()); + return SSL_TLSEXT_ERR_OK; +} + +static bool libssl_validate_ocsp_response(const std::string& response) +{ + auto responsePtr = reinterpret_cast(response.data()); + std::unique_ptr resp(d2i_OCSP_RESPONSE(nullptr, &responsePtr, response.size()), OCSP_RESPONSE_free); + if (resp == nullptr) { + throw std::runtime_error("Unable to parse OCSP response"); + } + + int status = OCSP_response_status(resp.get()); + if (status != OCSP_RESPONSE_STATUS_SUCCESSFUL) { + throw std::runtime_error("OCSP response status is not successful: " + std::to_string(status)); + } + + std::unique_ptr basic(OCSP_response_get1_basic(resp.get()), OCSP_BASICRESP_free); + if (basic == nullptr) { + throw std::runtime_error("Error getting a basic OCSP response"); + } + + if (OCSP_resp_count(basic.get()) != 1) { + throw std::runtime_error("More than one single response in an OCSP basic response"); + } + + auto singleResponse = OCSP_resp_get0(basic.get(), 0); + if (singleResponse == nullptr) { + throw std::runtime_error("Error getting a single response from the basic OCSP response"); + } + + int reason; + ASN1_GENERALIZEDTIME* revTime = nullptr; + ASN1_GENERALIZEDTIME* thisUpdate = nullptr; + ASN1_GENERALIZEDTIME* nextUpdate = nullptr; + + auto singleResponseStatus = OCSP_single_get0_status(singleResponse, &reason, &revTime, &thisUpdate, &nextUpdate); + if (singleResponseStatus != V_OCSP_CERTSTATUS_GOOD) { + throw std::runtime_error("Invalid status for OCSP single response (" + std::to_string(singleResponseStatus) + ")"); + } + if (thisUpdate == nullptr || nextUpdate == nullptr) { + throw std::runtime_error("Error getting validity of OCSP single response"); + } + + auto validityResult = OCSP_check_validity(thisUpdate, nextUpdate, /* 5 minutes of leeway */ 5 * 60, -1); + if (validityResult == 0) { + throw std::runtime_error("OCSP single response is not yet, or no longer, valid"); + } + + return true; +} + +std::map libssl_load_ocsp_responses(const std::vector& ocspFiles, std::vector keyTypes) +{ + std::map ocspResponses; + + if (ocspFiles.size() > keyTypes.size()) { + throw std::runtime_error("More OCSP files than certificates and keys loaded!"); + } + + size_t count = 0; + for (const auto& filename : ocspFiles) { + std::ifstream file(filename, std::ios::binary); + std::string content; + while(file) { + char buffer[4096]; + file.read(buffer, sizeof(buffer)); + if (file.bad()) { + file.close(); + throw std::runtime_error("Unable to load OCSP response from '" + filename + "'"); + } + content.append(buffer, file.gcount()); + } + file.close(); + + try { + libssl_validate_ocsp_response(content); + ocspResponses.insert({keyTypes.at(count), std::move(content)}); + } + catch (const std::exception& e) { + throw std::runtime_error("Error checking the validity of OCSP response from '" + filename + "': " + e.what()); + } + ++count; + } + + return ocspResponses; +} + +int libssl_get_last_key_type(std::unique_ptr& ctx) +{ +#ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY + auto pkey = SSL_CTX_get0_privatekey(ctx.get()); +#else + auto temp = std::unique_ptr(SSL_new(ctx.get()), SSL_free); + if (!temp) { + return -1; + } + auto pkey = SSL_get_privatekey(temp.get()); +#endif + + if (!pkey) { + return -1; + } + + return EVP_PKEY_base_id(pkey); +} + +#ifdef HAVE_OCSP_BASIC_SIGN +bool libssl_generate_ocsp_response(const std::string& certFile, const std::string& caCert, const std::string& caKey, const std::string& outFile, int ndays, int nmin) +{ + const EVP_MD* rmd = EVP_sha256(); + + auto fp = std::unique_ptr(fopen(certFile.c_str(), "r"), fclose); + if (!fp) { + throw std::runtime_error("Unable to open '" + certFile + "' when loading the certificate to generate an OCSP response"); + } + auto cert = std::unique_ptr(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free); + + fp = std::unique_ptr(fopen(caCert.c_str(), "r"), fclose); + if (!fp) { + throw std::runtime_error("Unable to open '" + caCert + "' when loading the issuer certificate to generate an OCSP response"); + } + auto issuer = std::unique_ptr(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free); + fp = std::unique_ptr(fopen(caKey.c_str(), "r"), fclose); + if (!fp) { + throw std::runtime_error("Unable to open '" + caKey + "' when loading the issuer key to generate an OCSP response"); + } + auto issuerKey = std::unique_ptr(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr), EVP_PKEY_free); + fp.reset(); + + auto bs = std::unique_ptr(OCSP_BASICRESP_new(), OCSP_BASICRESP_free); + auto thisupd = std::unique_ptr(X509_gmtime_adj(nullptr, 0), ASN1_TIME_free); + auto nextupd = std::unique_ptr(X509_time_adj_ex(nullptr, ndays, nmin * 60, nullptr), ASN1_TIME_free); + + auto cid = std::unique_ptr(OCSP_cert_to_id(rmd, cert.get(), issuer.get()), OCSP_CERTID_free); + OCSP_basic_add1_status(bs.get(), cid.get(), V_OCSP_CERTSTATUS_GOOD, 0, nullptr, thisupd.get(), nextupd.get()); + + if (OCSP_basic_sign(bs.get(), issuer.get(), issuerKey.get(), rmd, nullptr, OCSP_NOCERTS) != 1) { + throw std::runtime_error("Error while signing the OCSP response"); + } + + auto resp = std::unique_ptr(OCSP_response_create(OCSP_RESPONSE_STATUS_SUCCESSFUL, bs.get()), OCSP_RESPONSE_free); + auto bio = std::unique_ptr(BIO_new_file(outFile.c_str(), "wb"), BIO_vfree); + if (!bio) { + throw std::runtime_error("Error opening file for writing the OCSP response"); + } + + // i2d_OCSP_RESPONSE_bio(bio.get(), resp.get()) is unusable from C++ because of an invalid cast + ASN1_i2d_bio((i2d_of_void*)i2d_OCSP_RESPONSE, bio.get(), (unsigned char*)resp.get()); + + return true; +} +#endif /* HAVE_OCSP_BASIC_SIGN */ + +LibsslTLSVersion libssl_tls_version_from_string(const std::string& str) +{ + if (str == "tls1.0") { + return LibsslTLSVersion::TLS10; + } + if (str == "tls1.1") { + return LibsslTLSVersion::TLS11; + } + if (str == "tls1.2") { + return LibsslTLSVersion::TLS12; + } + if (str == "tls1.3") { + return LibsslTLSVersion::TLS13; + } + throw std::runtime_error("Unknown TLS version '" + str); +} + +const std::string& libssl_tls_version_to_string(LibsslTLSVersion version) +{ + static const std::map versions = { + { LibsslTLSVersion::TLS10, "tls1.0" }, + { LibsslTLSVersion::TLS11, "tls1.1" }, + { LibsslTLSVersion::TLS12, "tls1.2" }, + { LibsslTLSVersion::TLS13, "tls1.3" } + }; + + const auto& it = versions.find(version); + if (it == versions.end()) { + throw std::runtime_error("Unknown TLS version (" + std::to_string((int)version) + ")"); + } + return it->second; +} + +bool libssl_set_min_tls_version(std::unique_ptr& ctx, LibsslTLSVersion version) +{ +#if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version) + /* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated + Warning: SSL_CTX_set_min_proto_version is a function-like macro in OpenSSL */ + int vers; + switch(version) { + case LibsslTLSVersion::TLS10: + vers = TLS1_VERSION; + break; + case LibsslTLSVersion::TLS11: + vers = TLS1_1_VERSION; + break; + case LibsslTLSVersion::TLS12: + vers = TLS1_2_VERSION; + break; + case LibsslTLSVersion::TLS13: +#ifdef TLS1_3_VERSION + vers = TLS1_3_VERSION; +#else + return false; +#endif /* TLS1_3_VERSION */ + break; + default: + return false; + } + + if (SSL_CTX_set_min_proto_version(ctx.get(), vers) != 1) { + return false; + } + return true; +#else + long vers = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3; + switch(version) { + case LibsslTLSVersion::TLS10: + break; + case LibsslTLSVersion::TLS11: + vers |= SSL_OP_NO_TLSv1; + break; + case LibsslTLSVersion::TLS12: + vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1; + break; + case LibsslTLSVersion::TLS13: + vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2; + break; + default: + return false; + } + + long options = SSL_CTX_get_options(ctx.get()); + SSL_CTX_set_options(ctx.get(), options | vers); + return true; +#endif +} + +OpenSSLTLSTicketKeysRing::OpenSSLTLSTicketKeysRing(size_t capacity) +{ + d_ticketKeys.set_capacity(capacity); +} + +OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() +{ +} + +void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr newKey) +{ + WriteLock wl(&d_lock); + d_ticketKeys.push_front(newKey); +} + +std::shared_ptr OpenSSLTLSTicketKeysRing::getEncryptionKey() +{ + ReadLock rl(&d_lock); + return d_ticketKeys.front(); +} + +std::shared_ptr OpenSSLTLSTicketKeysRing::getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey) +{ + ReadLock rl(&d_lock); + for (auto& key : d_ticketKeys) { + if (key->nameMatches(name)) { + activeKey = (key == d_ticketKeys.front()); + return key; + } + } + return nullptr; +} + +size_t OpenSSLTLSTicketKeysRing::getKeysCount() +{ + ReadLock rl(&d_lock); + return d_ticketKeys.size(); +} + +void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile) +{ + bool keyLoaded = false; + std::ifstream file(keyFile); + try { + do { + auto newKey = std::make_shared(file); + addKey(newKey); + keyLoaded = true; + } + while (!file.fail()); + } + catch (const std::exception& e) { + /* if we haven't been able to load at least one key, fail */ + if (!keyLoaded) { + throw; + } + } + + file.close(); +} + +void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t now) +{ + auto newKey = std::make_shared(); + addKey(newKey); +} + +OpenSSLTLSTicketKey::OpenSSLTLSTicketKey() +{ + if (RAND_bytes(d_name, sizeof(d_name)) != 1) { + throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key"); + } + + if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) { + throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key"); + } + + if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) { + throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key"); + } +#ifdef HAVE_LIBSODIUM + sodium_mlock(d_name, sizeof(d_name)); + sodium_mlock(d_cipherKey, sizeof(d_cipherKey)); + sodium_mlock(d_hmacKey, sizeof(d_hmacKey)); +#endif /* HAVE_LIBSODIUM */ +} + +OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(ifstream& file) +{ + file.read(reinterpret_cast(d_name), sizeof(d_name)); + file.read(reinterpret_cast(d_cipherKey), sizeof(d_cipherKey)); + file.read(reinterpret_cast(d_hmacKey), sizeof(d_hmacKey)); + + if (file.fail()) { + throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file"); + } +#ifdef HAVE_LIBSODIUM + sodium_mlock(d_name, sizeof(d_name)); + sodium_mlock(d_cipherKey, sizeof(d_cipherKey)); + sodium_mlock(d_hmacKey, sizeof(d_hmacKey)); +#endif /* HAVE_LIBSODIUM */ +} + +OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey() +{ +#ifdef HAVE_LIBSODIUM + sodium_munlock(d_name, sizeof(d_name)); + sodium_munlock(d_cipherKey, sizeof(d_cipherKey)); + sodium_munlock(d_hmacKey, sizeof(d_hmacKey)); +#else + OPENSSL_cleanse(d_name, sizeof(d_name)); + OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey)); + OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey)); +#endif /* HAVE_LIBSODIUM */ +} + +bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const +{ + return (memcmp(d_name, name, sizeof(d_name)) == 0); +} + +int OpenSSLTLSTicketKey::encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const +{ + memcpy(keyName, d_name, sizeof(d_name)); + + if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) { + return -1; + } + + if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) { + return -1; + } + + if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) { + return -1; + } + + return 1; +} + +bool OpenSSLTLSTicketKey::decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const +{ + if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) { + return false; + } + + if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) { + return false; + } + + return true; +} + +std::unique_ptr libssl_init_server_context(const TLSConfig& config, + std::map& ocspResponses) +{ + auto ctx = std::unique_ptr(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free); + + int sslOptions = + SSL_OP_NO_SSLv2 | + SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION | + SSL_OP_SINGLE_DH_USE | + SSL_OP_SINGLE_ECDH_USE; + + if (!config.d_enableTickets || config.d_numberOfTicketsKeys == 0) { + /* for TLS 1.3 this means no stateless tickets, but stateful tickets might still be issued, + which is something we don't want. */ + sslOptions |= SSL_OP_NO_TICKET; + /* really disable all tickets */ +#ifdef HAVE_SSL_CTX_SET_NUM_TICKETS + SSL_CTX_set_num_tickets(ctx.get(), 0); +#endif /* HAVE_SSL_CTX_SET_NUM_TICKETS */ + } + + if (config.d_sessionTimeout > 0) { + SSL_CTX_set_timeout(ctx.get(), config.d_sessionTimeout); + } + + if (config.d_preferServerCiphers) { + sslOptions |= SSL_OP_CIPHER_SERVER_PREFERENCE; +#ifdef SSL_OP_PRIORITIZE_CHACHA + sslOptions |= SSL_OP_PRIORITIZE_CHACHA; +#endif /* SSL_OP_PRIORITIZE_CHACHA */ + } + + SSL_CTX_set_options(ctx.get(), sslOptions); + if (!libssl_set_min_tls_version(ctx, config.d_minTLSVersion)) { + throw std::runtime_error("Failed to set the minimum version to '" + libssl_tls_version_to_string(config.d_minTLSVersion)); + } + +#ifdef SSL_CTX_set_ecdh_auto + SSL_CTX_set_ecdh_auto(ctx.get(), 1); +#endif + + if (config.d_maxStoredSessions == 0) { + /* disable stored sessions entirely */ + SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_OFF); + } + else { + /* use the internal built-in cache to store sessions */ + SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_SERVER); + SSL_CTX_sess_set_cache_size(ctx.get(), config.d_maxStoredSessions); + } + + /* we need to set this callback to acknowledge the server name sent by the client, + otherwise it will not stored in the session and will not be accessible when the + session is resumed, causing SSL_get_servername to return nullptr */ + SSL_CTX_set_tlsext_servername_callback(ctx.get(), &libssl_server_name_callback); + + std::vector keyTypes; + /* load certificate and private key */ + for (const auto& pair : config.d_certKeyPairs) { + if (SSL_CTX_use_certificate_chain_file(ctx.get(), pair.first.c_str()) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.first); + } + if (SSL_CTX_use_PrivateKey_file(ctx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.second); + } + if (SSL_CTX_check_private_key(ctx.get()) != 1) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("The key from '" + pair.second + "' does not match the certificate from '" + pair.first + "'"); + } + /* store the type of the new key, we might need it later to select the right OCSP stapling response */ + auto keyType = libssl_get_last_key_type(ctx); + if (keyType < 0) { + throw std::runtime_error("The key from '" + pair.second + "' has an unknown type"); + } + keyTypes.push_back(keyType); + } + + if (!config.d_ocspFiles.empty()) { + try { + ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, keyTypes); + } + catch(const std::exception& e) { + throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what())); + } + } + + if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) { + throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers); + } + +#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES + if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) { + throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13); + } +#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */ + + return ctx; +} + +#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK +static void libssl_key_log_file_callback(const SSL* ssl, const char* line) +{ + SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl); + if (sslCtx == nullptr) { + return; + } + + auto fp = reinterpret_cast(SSL_CTX_get_ex_data(sslCtx, s_keyLogIndex)); + if (fp == nullptr) { + return; + } + + fprintf(fp, "%s\n", line); + fflush(fp); +} +#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */ + +std::unique_ptr libssl_set_key_log_file(std::unique_ptr& ctx, const std::string& logFile) +{ +#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK + auto fp = std::unique_ptr(fopen(logFile.c_str(), "a"), fclose); + if (!fp) { + throw std::runtime_error("Error opening TLS log file '" + logFile + "'"); + } + + SSL_CTX_set_ex_data(ctx.get(), s_keyLogIndex, fp.get()); + SSL_CTX_set_keylog_callback(ctx.get(), &libssl_key_log_file_callback); + + return fp; +#else + return std::unique_ptr(nullptr, fclose); +#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */ +} + +#endif /* HAVE_LIBSSL */ diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc new file mode 100644 index 0000000000..e308e2791b --- /dev/null +++ b/pdns/tcpiohandler.cc @@ -0,0 +1,926 @@ + +#include "config.h" +#include "dolog.hh" +#include "iputils.hh" +#include "lock.hh" +#include "tcpiohandler.hh" + +#ifdef HAVE_LIBSODIUM +#include +#endif /* HAVE_LIBSODIUM */ + +#ifdef HAVE_DNS_OVER_TLS +#ifdef HAVE_LIBSSL +#include +#include +#include +#include + +#include "libssl.hh" + +class OpenSSLFrontendContext +{ +public: + OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys) + { + registerOpenSSLUser(); + + d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses); + if (!d_tlsCtx) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort()); + } + } + + void cleanup() + { + d_tlsCtx.reset(); + + unregisterOpenSSLUser(); + } + + OpenSSLTLSTicketKeysRing d_ticketKeys; + std::map d_ocspResponses; + std::unique_ptr d_tlsCtx{nullptr, SSL_CTX_free}; + std::unique_ptr d_keyLogFile{nullptr, fclose}; +}; + +class OpenSSLTLSConnection: public TLSConnection +{ +public: + OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr feContext): d_feContext(feContext), d_conn(std::unique_ptr(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) + { + d_socket = socket; + + if (!s_initTLSConnIndex.test_and_set()) { + /* not initialized yet */ + s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + if (s_tlsConnIndex == -1) { + throw std::runtime_error("Error getting an index for TLS connection data"); + } + } + + if (!d_conn) { + vinfolog("Error creating TLS object"); + if (g_verbose) { + ERR_print_errors_fp(stderr); + } + throw std::runtime_error("Error creating TLS object"); + } + + if (!SSL_set_fd(d_conn.get(), d_socket)) { + throw std::runtime_error("Error assigning socket"); + } + + SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this); + } + + IOState convertIORequestToIOState(int res) const + { + int error = SSL_get_error(d_conn.get(), res); + if (error == SSL_ERROR_WANT_READ) { + return IOState::NeedRead; + } + else if (error == SSL_ERROR_WANT_WRITE) { + return IOState::NeedWrite; + } + else if (error == SSL_ERROR_SYSCALL) { + throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno))); + } + else { + throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error)); + } + } + + void handleIORequest(int res, unsigned int timeout) + { + auto state = convertIORequestToIOState(res); + if (state == IOState::NeedRead) { + res = waitForData(d_socket, timeout); + if (res == 0) { + throw std::runtime_error("Timeout while reading from TLS connection"); + } + else if (res < 0) { + throw std::runtime_error("Error waiting to read from TLS connection"); + } + } + else if (state == IOState::NeedWrite) { + res = waitForRWData(d_socket, false, timeout, 0); + if (res == 0) { + throw std::runtime_error("Timeout while writing to TLS connection"); + } + else if (res < 0) { + throw std::runtime_error("Error waiting to write to TLS connection"); + } + } + } + + IOState tryHandshake() override + { + int res = SSL_accept(d_conn.get()); + if (res == 1) { + return IOState::Done; + } + else if (res < 0) { + return convertIORequestToIOState(res); + } + + throw std::runtime_error("Error accepting TLS connection"); + } + + void doHandshake() override + { + int res = 0; + do { + res = SSL_accept(d_conn.get()); + if (res < 0) { + handleIORequest(res, d_timeout); + } + } + while (res < 0); + + if (res != 1) { + throw std::runtime_error("Error accepting TLS connection"); + } + } + + IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override + { + do { + int res = SSL_write(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toWrite - pos)); + if (res <= 0) { + return convertIORequestToIOState(res); + } + else { + pos += static_cast(res); + } + } + while (pos < toWrite); + return IOState::Done; + } + + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override + { + do { + int res = SSL_read(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toRead - pos)); + if (res <= 0) { + return convertIORequestToIOState(res); + } + else { + pos += static_cast(res); + } + } + while (pos < toRead); + return IOState::Done; + } + + size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override + { + size_t got = 0; + time_t start = 0; + unsigned int remainingTime = totalTimeout; + if (totalTimeout) { + start = time(nullptr); + } + + do { + int res = SSL_read(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); + if (res <= 0) { + handleIORequest(res, readTimeout); + } + else { + got += static_cast(res); + } + + if (totalTimeout) { + time_t now = time(nullptr); + unsigned int elapsed = now - start; + if (now < start || elapsed >= remainingTime) { + throw runtime_error("Timeout while reading data"); + } + start = now; + remainingTime -= elapsed; + } + } + while (got < bufferSize); + + return got; + } + + size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override + { + size_t got = 0; + do { + int res = SSL_write(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got)); + if (res <= 0) { + handleIORequest(res, writeTimeout); + } + else { + got += static_cast(res); + } + } + while (got < bufferSize); + + return got; + } + + bool hasBufferedData() const override + { + if (d_conn) { + return SSL_pending(d_conn.get()) > 0; + } + + return false; + } + + void close() override + { + if (d_conn) { + SSL_shutdown(d_conn.get()); + } + } + + std::string getServerNameIndication() const override + { + if (d_conn) { + const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name); + if (value) { + return std::string(value); + } + } + return std::string(); + } + + LibsslTLSVersion getTLSVersion() const override + { + auto proto = SSL_version(d_conn.get()); + switch (proto) { + case TLS1_VERSION: + return LibsslTLSVersion::TLS10; + case TLS1_1_VERSION: + return LibsslTLSVersion::TLS11; + case TLS1_2_VERSION: + return LibsslTLSVersion::TLS12; +#ifdef TLS1_3_VERSION + case TLS1_3_VERSION: + return LibsslTLSVersion::TLS13; +#endif /* TLS1_3_VERSION */ + default: + return LibsslTLSVersion::Unknown; + } + } + + bool hasSessionBeenResumed() const override + { + if (d_conn) { + return SSL_session_reused(d_conn.get()) != 0; + } + return false; + } + + static int s_tlsConnIndex; + +private: + static std::atomic_flag s_initTLSConnIndex; + + std::shared_ptr d_feContext; + std::unique_ptr d_conn; + unsigned int d_timeout; +}; + +std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT; +int OpenSSLTLSConnection::s_tlsConnIndex = -1; + +class OpenSSLTLSIOCtx: public TLSCtx +{ +public: + OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared(fe.d_addr, fe.d_tlsConfig)) + { + d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; + + if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) { + /* use our own ticket keys handler so we can rotate them */ + SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); + libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get()); + } + + if (!d_feContext->d_ocspResponses.empty()) { + SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb); + SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses); + } + + libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters); + + if (!fe.d_tlsConfig.d_keyLogFile.empty()) { + d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile); + } + + try { + if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { + handleTicketsKeyRotation(time(nullptr)); + } + else { + OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); + } + } + catch (const std::exception& e) { + throw; + } + } + + ~OpenSSLTLSIOCtx() override + { + } + + static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) + { + OpenSSLFrontendContext* ctx = reinterpret_cast(libssl_get_ticket_key_callback_data(s)); + if (ctx == nullptr) { + return -1; + } + + int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc); + if (enc == 0) { + if (ret == 0 || ret == 2) { + OpenSSLTLSConnection* conn = reinterpret_cast(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex)); + if (conn) { + if (ret == 0) { + conn->setUnknownTicketKey(); + } + else if (ret == 2) { + conn->setResumedFromInactiveTicketKey(); + } + } + } + } + + return ret; + } + + static int ocspStaplingCb(SSL* ssl, void* arg) + { + if (ssl == nullptr || arg == nullptr) { + return SSL_TLSEXT_ERR_NOACK; + } + const auto ocspMap = reinterpret_cast*>(arg); + return libssl_ocsp_stapling_callback(ssl, *ocspMap); + } + + std::unique_ptr getConnection(int socket, unsigned int timeout, time_t now) override + { + handleTicketsKeyRotation(now); + + return std::unique_ptr(new OpenSSLTLSConnection(socket, timeout, d_feContext)); + } + + void rotateTicketsKey(time_t now) override + { + d_feContext->d_ticketKeys.rotateTicketsKey(now); + + if (d_ticketsKeyRotationDelay > 0) { + d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; + } + } + + void loadTicketsKeys(const std::string& keyFile) override final + { + d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); + + if (d_ticketsKeyRotationDelay > 0) { + d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; + } + } + + size_t getTicketsKeysCount() override + { + return d_feContext->d_ticketKeys.getKeysCount(); + } + +private: + std::shared_ptr d_feContext; +}; + +#endif /* HAVE_LIBSSL */ + +#ifdef HAVE_GNUTLS +#include +#include + +static void safe_memory_lock(void* data, size_t size) +{ +#ifdef HAVE_LIBSODIUM + sodium_mlock(data, size); +#endif +} + +static void safe_memory_release(void* data, size_t size) +{ +#ifdef HAVE_LIBSODIUM + sodium_munlock(data, size); +#elif defined(HAVE_EXPLICIT_BZERO) + explicit_bzero(data, size); +#elif defined(HAVE_EXPLICIT_MEMSET) + explicit_memset(data, 0, size); +#elif defined(HAVE_GNUTLS_MEMSET) + gnutls_memset(data, 0, size); +#else + /* shamelessly taken from Dovecot's src/lib/safe-memset.c */ + volatile unsigned int volatile_zero_idx = 0; + volatile unsigned char *p = reinterpret_cast(data); + + if (size == 0) + return; + + do { + memset(data, 0, size); + } while (p[volatile_zero_idx] != 0); +#endif +} + +class GnuTLSTicketsKey +{ +public: + GnuTLSTicketsKey() + { + if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error generating tickets key for TLS context"); + } + + safe_memory_lock(d_key.data, d_key.size); + } + + GnuTLSTicketsKey(const std::string& keyFile) + { + /* to be sure we are loading the correct amount of data, which + may change between versions, let's generate a correct key first */ + if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context"); + } + + safe_memory_lock(d_key.data, d_key.size); + + try { + ifstream file(keyFile); + file.read(reinterpret_cast(d_key.data), d_key.size); + + if (file.fail()) { + file.close(); + throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile); + } + + file.close(); + } + catch (const std::exception& e) { + safe_memory_release(d_key.data, d_key.size); + gnutls_free(d_key.data); + d_key.data = nullptr; + throw; + } + } + + ~GnuTLSTicketsKey() + { + if (d_key.data != nullptr && d_key.size > 0) { + safe_memory_release(d_key.data, d_key.size); + } + gnutls_free(d_key.data); + d_key.data = nullptr; + } + const gnutls_datum_t& getKey() const + { + return d_key; + } + +private: + gnutls_datum_t d_key{nullptr, 0}; +}; + +class GnuTLSConnection: public TLSConnection +{ +public: + + GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr& ticketsKey, bool enableTickets): d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey) + { + unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK; +#ifdef GNUTLS_NO_SIGNAL + sslOptions |= GNUTLS_NO_SIGNAL; +#endif + + d_socket = socket; + + gnutls_session_t conn; + if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error creating TLS connection"); + } + + d_conn = std::unique_ptr(conn, gnutls_deinit); + conn = nullptr; + + if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error setting certificate and key to TLS connection"); + } + + if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error setting ciphers to TLS connection"); + } + + if (enableTickets && d_ticketsKey) { + const gnutls_datum_t& key = d_ticketsKey->getKey(); + if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error setting the tickets key to TLS connection"); + } + } + + gnutls_transport_set_int(d_conn.get(), d_socket); + + /* timeouts are in milliseconds */ + gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000); + gnutls_record_set_timeout(d_conn.get(), timeout * 1000); + } + + void doHandshake() override + { + int ret = 0; + do { + ret = gnutls_handshake(d_conn.get()); + if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { + throw std::runtime_error("Error accepting a new connection"); + } + } + while (ret < 0 && ret == GNUTLS_E_INTERRUPTED); + } + + IOState tryHandshake() override + { + int ret = 0; + + do { + ret = gnutls_handshake(d_conn.get()); + if (ret == GNUTLS_E_SUCCESS) { + return IOState::Done; + } + else if (ret == GNUTLS_E_AGAIN) { + return IOState::NeedRead; + } + else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { + throw std::runtime_error("Error accepting a new connection"); + } + } while (ret == GNUTLS_E_INTERRUPTED); + + throw std::runtime_error("Error accepting a new connection"); + } + + IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override + { + do { + ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toWrite - pos); + if (res == 0) { + throw std::runtime_error("Error writing to TLS connection"); + } + else if (res > 0) { + pos += static_cast(res); + } + else if (res < 0) { + if (gnutls_error_is_fatal(res)) { + throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); + } + else if (res == GNUTLS_E_AGAIN) { + return IOState::NeedWrite; + } + warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + } + } + while (pos < toWrite); + return IOState::Done; + } + + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override + { + do { + ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toRead - pos); + if (res == 0) { + throw std::runtime_error("Error reading from TLS connection"); + } + else if (res > 0) { + pos += static_cast(res); + } + else if (res < 0) { + if (gnutls_error_is_fatal(res)) { + throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); + } + else if (res == GNUTLS_E_AGAIN) { + return IOState::NeedRead; + } + warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + } + } + while (pos < toRead); + return IOState::Done; + } + + size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override + { + size_t got = 0; + time_t start = 0; + unsigned int remainingTime = totalTimeout; + if (totalTimeout) { + start = time(nullptr); + } + + do { + ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); + if (res == 0) { + throw std::runtime_error("Error reading from TLS connection"); + } + else if (res > 0) { + got += static_cast(res); + } + else if (res < 0) { + if (gnutls_error_is_fatal(res)) { + throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); + } + else if (res == GNUTLS_E_AGAIN) { + int result = waitForData(d_socket, readTimeout); + if (result <= 0) { + throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result)); + } + } + else { + vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res)); + } + } + + if (totalTimeout) { + time_t now = time(nullptr); + unsigned int elapsed = now - start; + if (now < start || elapsed >= remainingTime) { + throw runtime_error("Timeout while reading data"); + } + start = now; + remainingTime -= elapsed; + } + } + while (got < bufferSize); + + return got; + } + + size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override + { + size_t got = 0; + + do { + ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); + if (res == 0) { + throw std::runtime_error("Error writing to TLS connection"); + } + else if (res > 0) { + got += static_cast(res); + } + else if (res < 0) { + if (gnutls_error_is_fatal(res)) { + throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); + } + else if (res == GNUTLS_E_AGAIN) { + int result = waitForRWData(d_socket, false, writeTimeout, 0); + if (result <= 0) { + throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result)); + } + } + else { + vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + } + } + } + while (got < bufferSize); + + return got; + } + + bool hasBufferedData() const override + { + if (d_conn) { + return gnutls_record_check_pending(d_conn.get()) > 0; + } + + return false; + } + + std::string getServerNameIndication() const override + { + if (d_conn) { + unsigned int type; + size_t name_len = 256; + std::string sni; + sni.resize(name_len); + + int res = gnutls_server_name_get(d_conn.get(), const_cast(sni.c_str()), &name_len, &type, 0); + if (res == GNUTLS_E_SUCCESS) { + sni.resize(name_len); + return sni; + } + } + return std::string(); + } + + LibsslTLSVersion getTLSVersion() const override + { + auto proto = gnutls_protocol_get_version(d_conn.get()); + switch (proto) { + case GNUTLS_TLS1_0: + return LibsslTLSVersion::TLS10; + case GNUTLS_TLS1_1: + return LibsslTLSVersion::TLS11; + case GNUTLS_TLS1_2: + return LibsslTLSVersion::TLS12; +#if GNUTLS_VERSION_NUMBER >= 0x030603 + case GNUTLS_TLS1_3: + return LibsslTLSVersion::TLS13; +#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */ + default: + return LibsslTLSVersion::Unknown; + } + } + + bool hasSessionBeenResumed() const override + { + if (d_conn) { + return gnutls_session_is_resumed(d_conn.get()) != 0; + } + return false; + } + + void close() override + { + if (d_conn) { + gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR); + } + } + +private: + std::unique_ptr d_conn; + std::shared_ptr d_ticketsKey; +}; + +class GnuTLSIOCtx: public TLSCtx +{ +public: + GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets) + { + int rc = 0; + d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; + + gnutls_certificate_credentials_t creds; + rc = gnutls_certificate_allocate_credentials(&creds); + if (rc != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); + } + + d_creds = std::unique_ptr(creds, gnutls_certificate_free_credentials); + creds = nullptr; + + for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) { + rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM); + if (rc != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); + } + } + + size_t count = 0; + for (const auto& file : fe.d_tlsConfig.d_ocspFiles) { + rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count); + if (rc != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); + } + ++count; + } + +#if GNUTLS_VERSION_NUMBER >= 0x030600 + rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH); + if (rc != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); + } +#endif + + rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr); + if (rc != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort()); + } + + try { + if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { + handleTicketsKeyRotation(time(nullptr)); + } + else { + GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); + } + } + catch(const std::runtime_error& e) { + throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what()); + } + } + + virtual ~GnuTLSIOCtx() override + { + d_creds.reset(); + + if (d_priorityCache) { + gnutls_priority_deinit(d_priorityCache); + } + } + + std::unique_ptr getConnection(int socket, unsigned int timeout, time_t now) override + { + handleTicketsKeyRotation(now); + + std::shared_ptr ticketsKey; + { + ReadLock rl(&d_lock); + ticketsKey = d_ticketsKey; + } + + return std::unique_ptr(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets)); + } + + void rotateTicketsKey(time_t now) override + { + if (!d_enableTickets) { + return; + } + + auto newKey = std::make_shared(); + + { + WriteLock wl(&d_lock); + d_ticketsKey = newKey; + } + + if (d_ticketsKeyRotationDelay > 0) { + d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; + } + } + + void loadTicketsKeys(const std::string& file) override final + { + if (!d_enableTickets) { + return; + } + + auto newKey = std::make_shared(file); + { + WriteLock wl(&d_lock); + d_ticketsKey = newKey; + } + + if (d_ticketsKeyRotationDelay > 0) { + d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; + } + } + + size_t getTicketsKeysCount() override + { + ReadLock rl(&d_lock); + return d_ticketsKey != nullptr ? 1 : 0; + } + +private: + std::unique_ptr d_creds; + gnutls_priority_t d_priorityCache{nullptr}; + std::shared_ptr d_ticketsKey{nullptr}; + ReadWriteLock d_lock; + bool d_enableTickets{true}; +}; + +#endif /* HAVE_GNUTLS */ + +#endif /* HAVE_DNS_OVER_TLS */ + +bool TLSFrontend::setupTLS() +{ +#ifdef HAVE_DNS_OVER_TLS + /* get the "best" available provider */ + if (!d_provider.empty()) { +#ifdef HAVE_GNUTLS + if (d_provider == "gnutls") { + d_ctx = std::make_shared(*this); + return true; + } +#endif /* HAVE_GNUTLS */ +#ifdef HAVE_LIBSSL + if (d_provider == "openssl") { + d_ctx = std::make_shared(*this); + return true; + } +#endif /* HAVE_LIBSSL */ + } +#ifdef HAVE_LIBSSL + d_ctx = std::make_shared(*this); +#else /* HAVE_LIBSSL */ +#ifdef HAVE_GNUTLS + d_ctx = std::make_shared(*this); +#endif /* HAVE_GNUTLS */ +#endif /* HAVE_LIBSSL */ + +#endif /* HAVE_DNS_OVER_TLS */ + return true; +}