PDNS_CHECK_LIBCRYPTO_ECDSA
PDNS_CHECK_LIBCRYPTO_EDDSA
+AM_CONDITIONAL([HAVE_GNUTLS], [false])
+AM_CONDITIONAL([HAVE_LIBSSL], [false])
+
+PDNS_ENABLE_DNS_OVER_TLS
+AS_IF([test "x$enable_dns_over_tls" != "xno"], [
+ PDNS_WITH_LIBSSL
+ PDNS_WITH_GNUTLS
+ AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [
+ AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
+ ])
+])
+
PDNS_CHECK_RAGEL([pdns/dnslabeltext.cc], [www.powerdns.com])
PDNS_CHECK_CLOCK_GETTIME
-AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_TLS], [
+AC_DEFUN([PDNS_ENABLE_DNS_OVER_TLS], [
AC_MSG_CHECKING([whether to enable DNS over TLS support])
AC_ARG_ENABLE([dns-over-tls],
AS_HELP_STRING([--enable-dns-over-tls], [enable DNS over TLS support (requires GnuTLS or OpenSSL) @<:@default=no@:>@]),
-AC_DEFUN([DNSDIST_WITH_GNUTLS], [
+AC_DEFUN([PDNS_WITH_GNUTLS], [
AC_MSG_CHECKING([whether we will be linking in GnuTLS])
HAVE_GNUTLS=0
AC_ARG_WITH([gnutls],
-AC_DEFUN([DNSDIST_WITH_LIBSSL], [
+AC_DEFUN([PDNS_WITH_LIBSSL], [
AC_MSG_CHECKING([whether we will be linking in OpenSSL libssl])
HAVE_LIBSSL=0
AC_ARG_WITH([libssl],
dnsrecords.cc \
dnswriter.cc dnswriter.hh \
ednssubnet.cc iputils.cc \
+ libssl.cc libssl.hh \
logger.cc \
misc.cc misc.hh \
nsecrecords.cc \
sstuff.hh \
statbag.cc \
svc-records.cc svc-records.hh \
+ tcpiohandler.cc tcpiohandler.hh \
unix_utility.cc
+sdig_CPPFLAGS = $(AM_CPPFLAGS)
sdig_LDADD = $(LIBCRYPTO_LIBS)
sdig_LDFLAGS = $(AM_LDFLAGS) $(LIBCRYPTO_LDFLAGS)
sdig_LDADD += $(LIBCURL)
endif
+if HAVE_DNS_OVER_TLS
+
+if HAVE_GNUTLS
+sdig_CPPFLAGS += $(GNUTLS_CFLAGS)
+sdig_LDADD += -lgnutls
+endif
+
+if HAVE_LIBSSL
+sdig_CPPFLAGS += $(LIBSSL_CFLAGS)
+sdig_LDADD += $(LIBSSL_LIBS)
+endif
+
+if LIBSODIUM
+sdig_CPPFLAGS +=$(LIBSODIUM_CFLAGS)
+sdig_LDADD += $(LIBSODIUM_LIBS)
+endif
+
+endif
+
calidns_SOURCES = \
base32.cc \
base64.cc base64.hh \
PDNS_CHECK_LIBCRYPTO
-DNSDIST_ENABLE_DNS_OVER_TLS
+PDNS_ENABLE_DNS_OVER_TLS
DNSDIST_ENABLE_DNS_OVER_HTTPS
AS_IF([test "x$enable_dns_over_tls" != "xno" -o "x$enable_dns_over_https" != "xno"], [
- DNSDIST_WITH_LIBSSL
+ PDNS_WITH_LIBSSL
])
AS_IF([test "x$enable_dns_over_tls" != "xno"], [
- DNSDIST_WITH_GNUTLS
+ PDNS_WITH_GNUTLS
AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [
AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
+++ /dev/null
-
-#include "config.h"
-#include "libssl.hh"
-
-#ifdef HAVE_LIBSSL
-
-#include <atomic>
-#include <fstream>
-#include <cstring>
-#include <mutex>
-#include <pthread.h>
-
-#include <openssl/conf.h>
-#include <openssl/err.h>
-#include <openssl/ocsp.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-
-#ifdef HAVE_LIBSODIUM
-#include <sodium.h>
-#endif /* HAVE_LIBSODIUM */
-
-#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL)
-/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
-
-#include "lock.hh"
-static std::vector<std::mutex> openssllocks;
-
-extern "C" {
-static void openssl_pthreads_locking_callback(int mode, int type, const char *file, int line)
-{
- if (mode & CRYPTO_LOCK) {
- openssllocks.at(type).lock();
-
- } else {
- openssllocks.at(type).unlock();
- }
-}
-
-static unsigned long openssl_pthreads_id_callback()
-{
- return (unsigned long)pthread_self();
-}
-}
-
-static void openssl_thread_setup()
-{
- openssllocks = std::vector<std::mutex>(CRYPTO_num_locks());
- CRYPTO_set_id_callback(&openssl_pthreads_id_callback);
- CRYPTO_set_locking_callback(&openssl_pthreads_locking_callback);
-}
-
-static void openssl_thread_cleanup()
-{
- CRYPTO_set_locking_callback(nullptr);
- openssllocks.clear();
-}
-
-#endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) */
-
-static std::atomic<uint64_t> s_users;
-static int s_ticketsKeyIndex{-1};
-static int s_countersIndex{-1};
-static int s_keyLogIndex{-1};
-
-void registerOpenSSLUser()
-{
- if (s_users.fetch_add(1) == 0) {
-#ifdef HAVE_OPENSSL_INIT_CRYPTO
- /* load the default configuration file (or one specified via OPENSSL_CONF),
- which can then be used to load engines */
- OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CONFIG, nullptr);
-#endif
-
-#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL))
- SSL_load_error_strings();
- OpenSSL_add_ssl_algorithms();
- openssl_thread_setup();
-#endif
- s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
-
- if (s_ticketsKeyIndex == -1) {
- throw std::runtime_error("Error getting an index for tickets key");
- }
-
- s_countersIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
-
- if (s_countersIndex == -1) {
- throw std::runtime_error("Error getting an index for counters");
- }
-
- s_keyLogIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
-
- if (s_keyLogIndex == -1) {
- throw std::runtime_error("Error getting an index for TLS key logging");
- }
- }
-}
-
-void unregisterOpenSSLUser()
-{
- if (s_users.fetch_sub(1) == 1) {
-#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL))
- ERR_free_strings();
-
- EVP_cleanup();
-
- CONF_modules_finish();
- CONF_modules_free();
- CONF_modules_unload(1);
-
- CRYPTO_cleanup_all_ex_data();
- openssl_thread_cleanup();
-#endif
- }
-}
-
-void* libssl_get_ticket_key_callback_data(SSL* s)
-{
- SSL_CTX* sslCtx = SSL_get_SSL_CTX(s);
- if (sslCtx == nullptr) {
- return nullptr;
- }
-
- return SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex);
-}
-
-void libssl_set_ticket_key_callback_data(SSL_CTX* ctx, void* data)
-{
- SSL_CTX_set_ex_data(ctx, s_ticketsKeyIndex, data);
-}
-
-int libssl_ticket_key_callback(SSL *s, OpenSSLTLSTicketKeysRing& keyring, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
-{
- if (enc) {
- const auto key = keyring.getEncryptionKey();
- if (key == nullptr) {
- return -1;
- }
-
- return key->encrypt(keyName, iv, ectx, hctx);
- }
-
- bool activeEncryptionKey = false;
-
- const auto key = keyring.getDecryptionKey(keyName, activeEncryptionKey);
- if (key == nullptr) {
- /* we don't know this key, just create a new ticket */
- return 0;
- }
-
- if (key->decrypt(iv, ectx, hctx) == false) {
- return -1;
- }
-
- if (!activeEncryptionKey) {
- /* this key is not active, please encrypt the ticket content with the currently active one */
- return 2;
- }
-
- return 1;
-}
-
-static long libssl_server_name_callback(SSL* ssl, int* al, void* arg)
-{
- (void) al;
- (void) arg;
-
- if (SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name)) {
- return SSL_TLSEXT_ERR_OK;
- }
-
- return SSL_TLSEXT_ERR_NOACK;
-}
-
-static void libssl_info_callback(const SSL *ssl, int where, int ret)
-{
- SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl);
- if (sslCtx == nullptr) {
- return;
- }
-
- TLSErrorCounters* counters = reinterpret_cast<TLSErrorCounters*>(SSL_CTX_get_ex_data(sslCtx, s_countersIndex));
- if (counters == nullptr) {
- return;
- }
-
- if (where & SSL_CB_ALERT) {
- const long lastError = ERR_peek_last_error();
- switch (ERR_GET_REASON(lastError)) {
-#ifdef SSL_R_DH_KEY_TOO_SMALL
- case SSL_R_DH_KEY_TOO_SMALL:
- ++counters->d_dhKeyTooSmall;
- break;
-#endif /* SSL_R_DH_KEY_TOO_SMALL */
- case SSL_R_NO_SHARED_CIPHER:
- ++counters->d_noSharedCipher;
- break;
- case SSL_R_UNKNOWN_PROTOCOL:
- ++counters->d_unknownProtocol;
- break;
- case SSL_R_UNSUPPORTED_PROTOCOL:
-#ifdef SSL_R_VERSION_TOO_LOW
- case SSL_R_VERSION_TOO_LOW:
-#endif /* SSL_R_VERSION_TOO_LOW */
- ++counters->d_unsupportedProtocol;
- break;
- case SSL_R_INAPPROPRIATE_FALLBACK:
- ++counters->d_inappropriateFallBack;
- break;
- case SSL_R_UNKNOWN_CIPHER_TYPE:
- ++counters->d_unknownCipherType;
- break;
- case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE:
- ++counters->d_unknownKeyExchangeType;
- break;
- case SSL_R_UNSUPPORTED_ELLIPTIC_CURVE:
- ++counters->d_unsupportedEC;
- break;
- default:
- break;
- }
- }
-}
-
-void libssl_set_error_counters_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, TLSErrorCounters* counters)
-{
- SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters);
- SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback);
-}
-
-int libssl_ocsp_stapling_callback(SSL* ssl, const std::map<int, std::string>& ocspMap)
-{
- auto pkey = SSL_get_privatekey(ssl);
- if (pkey == nullptr) {
- return SSL_TLSEXT_ERR_NOACK;
- }
-
- /* look for an OCSP response for the corresponding private key type (RSA, ECDSA..) */
- const auto& data = ocspMap.find(EVP_PKEY_base_id(pkey));
- if (data == ocspMap.end()) {
- return SSL_TLSEXT_ERR_NOACK;
- }
-
- /* we need to allocate a copy because OpenSSL will free the pointer passed to SSL_set_tlsext_status_ocsp_resp() */
- void* copy = OPENSSL_malloc(data->second.size());
- if (copy == nullptr) {
- return SSL_TLSEXT_ERR_NOACK;
- }
-
- memcpy(copy, data->second.data(), data->second.size());
- SSL_set_tlsext_status_ocsp_resp(ssl, copy, data->second.size());
- return SSL_TLSEXT_ERR_OK;
-}
-
-static bool libssl_validate_ocsp_response(const std::string& response)
-{
- auto responsePtr = reinterpret_cast<const unsigned char *>(response.data());
- std::unique_ptr<OCSP_RESPONSE, void(*)(OCSP_RESPONSE*)> resp(d2i_OCSP_RESPONSE(nullptr, &responsePtr, response.size()), OCSP_RESPONSE_free);
- if (resp == nullptr) {
- throw std::runtime_error("Unable to parse OCSP response");
- }
-
- int status = OCSP_response_status(resp.get());
- if (status != OCSP_RESPONSE_STATUS_SUCCESSFUL) {
- throw std::runtime_error("OCSP response status is not successful: " + std::to_string(status));
- }
-
- std::unique_ptr<OCSP_BASICRESP, void(*)(OCSP_BASICRESP*)> basic(OCSP_response_get1_basic(resp.get()), OCSP_BASICRESP_free);
- if (basic == nullptr) {
- throw std::runtime_error("Error getting a basic OCSP response");
- }
-
- if (OCSP_resp_count(basic.get()) != 1) {
- throw std::runtime_error("More than one single response in an OCSP basic response");
- }
-
- auto singleResponse = OCSP_resp_get0(basic.get(), 0);
- if (singleResponse == nullptr) {
- throw std::runtime_error("Error getting a single response from the basic OCSP response");
- }
-
- int reason;
- ASN1_GENERALIZEDTIME* revTime = nullptr;
- ASN1_GENERALIZEDTIME* thisUpdate = nullptr;
- ASN1_GENERALIZEDTIME* nextUpdate = nullptr;
-
- auto singleResponseStatus = OCSP_single_get0_status(singleResponse, &reason, &revTime, &thisUpdate, &nextUpdate);
- if (singleResponseStatus != V_OCSP_CERTSTATUS_GOOD) {
- throw std::runtime_error("Invalid status for OCSP single response (" + std::to_string(singleResponseStatus) + ")");
- }
- if (thisUpdate == nullptr || nextUpdate == nullptr) {
- throw std::runtime_error("Error getting validity of OCSP single response");
- }
-
- auto validityResult = OCSP_check_validity(thisUpdate, nextUpdate, /* 5 minutes of leeway */ 5 * 60, -1);
- if (validityResult == 0) {
- throw std::runtime_error("OCSP single response is not yet, or no longer, valid");
- }
-
- return true;
-}
-
-std::map<int, std::string> libssl_load_ocsp_responses(const std::vector<std::string>& ocspFiles, std::vector<int> keyTypes)
-{
- std::map<int, std::string> ocspResponses;
-
- if (ocspFiles.size() > keyTypes.size()) {
- throw std::runtime_error("More OCSP files than certificates and keys loaded!");
- }
-
- size_t count = 0;
- for (const auto& filename : ocspFiles) {
- std::ifstream file(filename, std::ios::binary);
- std::string content;
- while(file) {
- char buffer[4096];
- file.read(buffer, sizeof(buffer));
- if (file.bad()) {
- file.close();
- throw std::runtime_error("Unable to load OCSP response from '" + filename + "'");
- }
- content.append(buffer, file.gcount());
- }
- file.close();
-
- try {
- libssl_validate_ocsp_response(content);
- ocspResponses.insert({keyTypes.at(count), std::move(content)});
- }
- catch (const std::exception& e) {
- throw std::runtime_error("Error checking the validity of OCSP response from '" + filename + "': " + e.what());
- }
- ++count;
- }
-
- return ocspResponses;
-}
-
-int libssl_get_last_key_type(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx)
-{
-#ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY
- auto pkey = SSL_CTX_get0_privatekey(ctx.get());
-#else
- auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(ctx.get()), SSL_free);
- if (!temp) {
- return -1;
- }
- auto pkey = SSL_get_privatekey(temp.get());
-#endif
-
- if (!pkey) {
- return -1;
- }
-
- return EVP_PKEY_base_id(pkey);
-}
-
-#ifdef HAVE_OCSP_BASIC_SIGN
-bool libssl_generate_ocsp_response(const std::string& certFile, const std::string& caCert, const std::string& caKey, const std::string& outFile, int ndays, int nmin)
-{
- const EVP_MD* rmd = EVP_sha256();
-
- auto fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(certFile.c_str(), "r"), fclose);
- if (!fp) {
- throw std::runtime_error("Unable to open '" + certFile + "' when loading the certificate to generate an OCSP response");
- }
- auto cert = std::unique_ptr<X509, void(*)(X509*)>(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free);
-
- fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(caCert.c_str(), "r"), fclose);
- if (!fp) {
- throw std::runtime_error("Unable to open '" + caCert + "' when loading the issuer certificate to generate an OCSP response");
- }
- auto issuer = std::unique_ptr<X509, void(*)(X509*)>(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free);
- fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(caKey.c_str(), "r"), fclose);
- if (!fp) {
- throw std::runtime_error("Unable to open '" + caKey + "' when loading the issuer key to generate an OCSP response");
- }
- auto issuerKey = std::unique_ptr<EVP_PKEY, void(*)(EVP_PKEY*)>(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr), EVP_PKEY_free);
- fp.reset();
-
- auto bs = std::unique_ptr<OCSP_BASICRESP, void(*)(OCSP_BASICRESP*)>(OCSP_BASICRESP_new(), OCSP_BASICRESP_free);
- auto thisupd = std::unique_ptr<ASN1_TIME, void(*)(ASN1_TIME*)>(X509_gmtime_adj(nullptr, 0), ASN1_TIME_free);
- auto nextupd = std::unique_ptr<ASN1_TIME, void(*)(ASN1_TIME*)>(X509_time_adj_ex(nullptr, ndays, nmin * 60, nullptr), ASN1_TIME_free);
-
- auto cid = std::unique_ptr<OCSP_CERTID, void(*)(OCSP_CERTID*)>(OCSP_cert_to_id(rmd, cert.get(), issuer.get()), OCSP_CERTID_free);
- OCSP_basic_add1_status(bs.get(), cid.get(), V_OCSP_CERTSTATUS_GOOD, 0, nullptr, thisupd.get(), nextupd.get());
-
- if (OCSP_basic_sign(bs.get(), issuer.get(), issuerKey.get(), rmd, nullptr, OCSP_NOCERTS) != 1) {
- throw std::runtime_error("Error while signing the OCSP response");
- }
-
- auto resp = std::unique_ptr<OCSP_RESPONSE, void(*)(OCSP_RESPONSE*)>(OCSP_response_create(OCSP_RESPONSE_STATUS_SUCCESSFUL, bs.get()), OCSP_RESPONSE_free);
- auto bio = std::unique_ptr<BIO, void(*)(BIO*)>(BIO_new_file(outFile.c_str(), "wb"), BIO_vfree);
- if (!bio) {
- throw std::runtime_error("Error opening file for writing the OCSP response");
- }
-
- // i2d_OCSP_RESPONSE_bio(bio.get(), resp.get()) is unusable from C++ because of an invalid cast
- ASN1_i2d_bio((i2d_of_void*)i2d_OCSP_RESPONSE, bio.get(), (unsigned char*)resp.get());
-
- return true;
-}
-#endif /* HAVE_OCSP_BASIC_SIGN */
-
-LibsslTLSVersion libssl_tls_version_from_string(const std::string& str)
-{
- if (str == "tls1.0") {
- return LibsslTLSVersion::TLS10;
- }
- if (str == "tls1.1") {
- return LibsslTLSVersion::TLS11;
- }
- if (str == "tls1.2") {
- return LibsslTLSVersion::TLS12;
- }
- if (str == "tls1.3") {
- return LibsslTLSVersion::TLS13;
- }
- throw std::runtime_error("Unknown TLS version '" + str);
-}
-
-const std::string& libssl_tls_version_to_string(LibsslTLSVersion version)
-{
- static const std::map<LibsslTLSVersion, std::string> versions = {
- { LibsslTLSVersion::TLS10, "tls1.0" },
- { LibsslTLSVersion::TLS11, "tls1.1" },
- { LibsslTLSVersion::TLS12, "tls1.2" },
- { LibsslTLSVersion::TLS13, "tls1.3" }
- };
-
- const auto& it = versions.find(version);
- if (it == versions.end()) {
- throw std::runtime_error("Unknown TLS version (" + std::to_string((int)version) + ")");
- }
- return it->second;
-}
-
-bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, LibsslTLSVersion version)
-{
-#if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version)
- /* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated
- Warning: SSL_CTX_set_min_proto_version is a function-like macro in OpenSSL */
- int vers;
- switch(version) {
- case LibsslTLSVersion::TLS10:
- vers = TLS1_VERSION;
- break;
- case LibsslTLSVersion::TLS11:
- vers = TLS1_1_VERSION;
- break;
- case LibsslTLSVersion::TLS12:
- vers = TLS1_2_VERSION;
- break;
- case LibsslTLSVersion::TLS13:
-#ifdef TLS1_3_VERSION
- vers = TLS1_3_VERSION;
-#else
- return false;
-#endif /* TLS1_3_VERSION */
- break;
- default:
- return false;
- }
-
- if (SSL_CTX_set_min_proto_version(ctx.get(), vers) != 1) {
- return false;
- }
- return true;
-#else
- long vers = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
- switch(version) {
- case LibsslTLSVersion::TLS10:
- break;
- case LibsslTLSVersion::TLS11:
- vers |= SSL_OP_NO_TLSv1;
- break;
- case LibsslTLSVersion::TLS12:
- vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1;
- break;
- case LibsslTLSVersion::TLS13:
- vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2;
- break;
- default:
- return false;
- }
-
- long options = SSL_CTX_get_options(ctx.get());
- SSL_CTX_set_options(ctx.get(), options | vers);
- return true;
-#endif
-}
-
-OpenSSLTLSTicketKeysRing::OpenSSLTLSTicketKeysRing(size_t capacity)
-{
- d_ticketKeys.set_capacity(capacity);
-}
-
-OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing()
-{
-}
-
-void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr<OpenSSLTLSTicketKey> newKey)
-{
- WriteLock wl(&d_lock);
- d_ticketKeys.push_front(newKey);
-}
-
-std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getEncryptionKey()
-{
- ReadLock rl(&d_lock);
- return d_ticketKeys.front();
-}
-
-std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey)
-{
- ReadLock rl(&d_lock);
- for (auto& key : d_ticketKeys) {
- if (key->nameMatches(name)) {
- activeKey = (key == d_ticketKeys.front());
- return key;
- }
- }
- return nullptr;
-}
-
-size_t OpenSSLTLSTicketKeysRing::getKeysCount()
-{
- ReadLock rl(&d_lock);
- return d_ticketKeys.size();
-}
-
-void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile)
-{
- bool keyLoaded = false;
- std::ifstream file(keyFile);
- try {
- do {
- auto newKey = std::make_shared<OpenSSLTLSTicketKey>(file);
- addKey(newKey);
- keyLoaded = true;
- }
- while (!file.fail());
- }
- catch (const std::exception& e) {
- /* if we haven't been able to load at least one key, fail */
- if (!keyLoaded) {
- throw;
- }
- }
-
- file.close();
-}
-
-void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t now)
-{
- auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
- addKey(newKey);
-}
-
-OpenSSLTLSTicketKey::OpenSSLTLSTicketKey()
-{
- if (RAND_bytes(d_name, sizeof(d_name)) != 1) {
- throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
- }
-
- if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) {
- throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
- }
-
- if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) {
- throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
- }
-#ifdef HAVE_LIBSODIUM
- sodium_mlock(d_name, sizeof(d_name));
- sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
- sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
-#endif /* HAVE_LIBSODIUM */
-}
-
-OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(ifstream& file)
-{
- file.read(reinterpret_cast<char*>(d_name), sizeof(d_name));
- file.read(reinterpret_cast<char*>(d_cipherKey), sizeof(d_cipherKey));
- file.read(reinterpret_cast<char*>(d_hmacKey), sizeof(d_hmacKey));
-
- if (file.fail()) {
- throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
- }
-#ifdef HAVE_LIBSODIUM
- sodium_mlock(d_name, sizeof(d_name));
- sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
- sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
-#endif /* HAVE_LIBSODIUM */
-}
-
-OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey()
-{
-#ifdef HAVE_LIBSODIUM
- sodium_munlock(d_name, sizeof(d_name));
- sodium_munlock(d_cipherKey, sizeof(d_cipherKey));
- sodium_munlock(d_hmacKey, sizeof(d_hmacKey));
-#else
- OPENSSL_cleanse(d_name, sizeof(d_name));
- OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey));
- OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey));
-#endif /* HAVE_LIBSODIUM */
-}
-
-bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const
-{
- return (memcmp(d_name, name, sizeof(d_name)) == 0);
-}
-
-int OpenSSLTLSTicketKey::encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
-{
- memcpy(keyName, d_name, sizeof(d_name));
-
- if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) {
- return -1;
- }
-
- if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
- return -1;
- }
-
- if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
- return -1;
- }
-
- return 1;
-}
-
-bool OpenSSLTLSTicketKey::decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
-{
- if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
- return false;
- }
-
- if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
- return false;
- }
-
- return true;
-}
-
-std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> libssl_init_server_context(const TLSConfig& config,
- std::map<int, std::string>& ocspResponses)
-{
- auto ctx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
-
- int sslOptions =
- SSL_OP_NO_SSLv2 |
- SSL_OP_NO_SSLv3 |
- SSL_OP_NO_COMPRESSION |
- SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
- SSL_OP_SINGLE_DH_USE |
- SSL_OP_SINGLE_ECDH_USE;
-
- if (!config.d_enableTickets || config.d_numberOfTicketsKeys == 0) {
- /* for TLS 1.3 this means no stateless tickets, but stateful tickets might still be issued,
- which is something we don't want. */
- sslOptions |= SSL_OP_NO_TICKET;
- /* really disable all tickets */
-#ifdef HAVE_SSL_CTX_SET_NUM_TICKETS
- SSL_CTX_set_num_tickets(ctx.get(), 0);
-#endif /* HAVE_SSL_CTX_SET_NUM_TICKETS */
- }
-
- if (config.d_sessionTimeout > 0) {
- SSL_CTX_set_timeout(ctx.get(), config.d_sessionTimeout);
- }
-
- if (config.d_preferServerCiphers) {
- sslOptions |= SSL_OP_CIPHER_SERVER_PREFERENCE;
-#ifdef SSL_OP_PRIORITIZE_CHACHA
- sslOptions |= SSL_OP_PRIORITIZE_CHACHA;
-#endif /* SSL_OP_PRIORITIZE_CHACHA */
- }
-
- SSL_CTX_set_options(ctx.get(), sslOptions);
- if (!libssl_set_min_tls_version(ctx, config.d_minTLSVersion)) {
- throw std::runtime_error("Failed to set the minimum version to '" + libssl_tls_version_to_string(config.d_minTLSVersion));
- }
-
-#ifdef SSL_CTX_set_ecdh_auto
- SSL_CTX_set_ecdh_auto(ctx.get(), 1);
-#endif
-
- if (config.d_maxStoredSessions == 0) {
- /* disable stored sessions entirely */
- SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_OFF);
- }
- else {
- /* use the internal built-in cache to store sessions */
- SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_SERVER);
- SSL_CTX_sess_set_cache_size(ctx.get(), config.d_maxStoredSessions);
- }
-
- /* we need to set this callback to acknowledge the server name sent by the client,
- otherwise it will not stored in the session and will not be accessible when the
- session is resumed, causing SSL_get_servername to return nullptr */
- SSL_CTX_set_tlsext_servername_callback(ctx.get(), &libssl_server_name_callback);
-
- std::vector<int> keyTypes;
- /* load certificate and private key */
- for (const auto& pair : config.d_certKeyPairs) {
- if (SSL_CTX_use_certificate_chain_file(ctx.get(), pair.first.c_str()) != 1) {
- ERR_print_errors_fp(stderr);
- throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.first);
- }
- if (SSL_CTX_use_PrivateKey_file(ctx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
- ERR_print_errors_fp(stderr);
- throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.second);
- }
- if (SSL_CTX_check_private_key(ctx.get()) != 1) {
- ERR_print_errors_fp(stderr);
- throw std::runtime_error("The key from '" + pair.second + "' does not match the certificate from '" + pair.first + "'");
- }
- /* store the type of the new key, we might need it later to select the right OCSP stapling response */
- auto keyType = libssl_get_last_key_type(ctx);
- if (keyType < 0) {
- throw std::runtime_error("The key from '" + pair.second + "' has an unknown type");
- }
- keyTypes.push_back(keyType);
- }
-
- if (!config.d_ocspFiles.empty()) {
- try {
- ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, keyTypes);
- }
- catch(const std::exception& e) {
- throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what()));
- }
- }
-
- if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) {
- throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers);
- }
-
-#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
- if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) {
- throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13);
- }
-#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
-
- return ctx;
-}
-
-#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
-static void libssl_key_log_file_callback(const SSL* ssl, const char* line)
-{
- SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl);
- if (sslCtx == nullptr) {
- return;
- }
-
- auto fp = reinterpret_cast<FILE*>(SSL_CTX_get_ex_data(sslCtx, s_keyLogIndex));
- if (fp == nullptr) {
- return;
- }
-
- fprintf(fp, "%s\n", line);
- fflush(fp);
-}
-#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */
-
-std::unique_ptr<FILE, int(*)(FILE*)> libssl_set_key_log_file(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::string& logFile)
-{
-#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
- auto fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(logFile.c_str(), "a"), fclose);
- if (!fp) {
- throw std::runtime_error("Error opening TLS log file '" + logFile + "'");
- }
-
- SSL_CTX_set_ex_data(ctx.get(), s_keyLogIndex, fp.get());
- SSL_CTX_set_keylog_callback(ctx.get(), &libssl_key_log_file_callback);
-
- return fp;
-#else
- return std::unique_ptr<FILE, int(*)(FILE*)>(nullptr, fclose);
-#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */
-}
-
-#endif /* HAVE_LIBSSL */
--- /dev/null
+../libssl.cc
\ No newline at end of file
--- /dev/null
+../../../m4/pdns_enable_tls.m4
\ No newline at end of file
--- /dev/null
+../../../m4/pdns_with_gnutls.m4
\ No newline at end of file
--- /dev/null
+../../../m4/pdns_with_libssl.m4
\ No newline at end of file
+++ /dev/null
-
-#include "config.h"
-#include "dolog.hh"
-#include "iputils.hh"
-#include "lock.hh"
-#include "tcpiohandler.hh"
-
-#ifdef HAVE_LIBSODIUM
-#include <sodium.h>
-#endif /* HAVE_LIBSODIUM */
-
-#ifdef HAVE_DNS_OVER_TLS
-#ifdef HAVE_LIBSSL
-#include <openssl/conf.h>
-#include <openssl/err.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-
-#include "libssl.hh"
-
-class OpenSSLFrontendContext
-{
-public:
- OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
- {
- registerOpenSSLUser();
-
- d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
- if (!d_tlsCtx) {
- ERR_print_errors_fp(stderr);
- throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
- }
- }
-
- void cleanup()
- {
- d_tlsCtx.reset();
-
- unregisterOpenSSLUser();
- }
-
- OpenSSLTLSTicketKeysRing d_ticketKeys;
- std::map<int, std::string> d_ocspResponses;
- std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
- std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
-};
-
-class OpenSSLTLSConnection: public TLSConnection
-{
-public:
- OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
- {
- d_socket = socket;
-
- if (!s_initTLSConnIndex.test_and_set()) {
- /* not initialized yet */
- s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
- if (s_tlsConnIndex == -1) {
- throw std::runtime_error("Error getting an index for TLS connection data");
- }
- }
-
- if (!d_conn) {
- vinfolog("Error creating TLS object");
- if (g_verbose) {
- ERR_print_errors_fp(stderr);
- }
- throw std::runtime_error("Error creating TLS object");
- }
-
- if (!SSL_set_fd(d_conn.get(), d_socket)) {
- throw std::runtime_error("Error assigning socket");
- }
-
- SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
- }
-
- IOState convertIORequestToIOState(int res) const
- {
- int error = SSL_get_error(d_conn.get(), res);
- if (error == SSL_ERROR_WANT_READ) {
- return IOState::NeedRead;
- }
- else if (error == SSL_ERROR_WANT_WRITE) {
- return IOState::NeedWrite;
- }
- else if (error == SSL_ERROR_SYSCALL) {
- throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno)));
- }
- else {
- throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
- }
- }
-
- void handleIORequest(int res, unsigned int timeout)
- {
- auto state = convertIORequestToIOState(res);
- if (state == IOState::NeedRead) {
- res = waitForData(d_socket, timeout);
- if (res == 0) {
- throw std::runtime_error("Timeout while reading from TLS connection");
- }
- else if (res < 0) {
- throw std::runtime_error("Error waiting to read from TLS connection");
- }
- }
- else if (state == IOState::NeedWrite) {
- res = waitForRWData(d_socket, false, timeout, 0);
- if (res == 0) {
- throw std::runtime_error("Timeout while writing to TLS connection");
- }
- else if (res < 0) {
- throw std::runtime_error("Error waiting to write to TLS connection");
- }
- }
- }
-
- IOState tryHandshake() override
- {
- int res = SSL_accept(d_conn.get());
- if (res == 1) {
- return IOState::Done;
- }
- else if (res < 0) {
- return convertIORequestToIOState(res);
- }
-
- throw std::runtime_error("Error accepting TLS connection");
- }
-
- void doHandshake() override
- {
- int res = 0;
- do {
- res = SSL_accept(d_conn.get());
- if (res < 0) {
- handleIORequest(res, d_timeout);
- }
- }
- while (res < 0);
-
- if (res != 1) {
- throw std::runtime_error("Error accepting TLS connection");
- }
- }
-
- IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override
- {
- do {
- int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
- if (res <= 0) {
- return convertIORequestToIOState(res);
- }
- else {
- pos += static_cast<size_t>(res);
- }
- }
- while (pos < toWrite);
- return IOState::Done;
- }
-
- IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
- {
- do {
- int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
- if (res <= 0) {
- return convertIORequestToIOState(res);
- }
- else {
- pos += static_cast<size_t>(res);
- }
- }
- while (pos < toRead);
- return IOState::Done;
- }
-
- size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
- {
- size_t got = 0;
- time_t start = 0;
- unsigned int remainingTime = totalTimeout;
- if (totalTimeout) {
- start = time(nullptr);
- }
-
- do {
- int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
- if (res <= 0) {
- handleIORequest(res, readTimeout);
- }
- else {
- got += static_cast<size_t>(res);
- }
-
- if (totalTimeout) {
- time_t now = time(nullptr);
- unsigned int elapsed = now - start;
- if (now < start || elapsed >= remainingTime) {
- throw runtime_error("Timeout while reading data");
- }
- start = now;
- remainingTime -= elapsed;
- }
- }
- while (got < bufferSize);
-
- return got;
- }
-
- size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
- {
- size_t got = 0;
- do {
- int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
- if (res <= 0) {
- handleIORequest(res, writeTimeout);
- }
- else {
- got += static_cast<size_t>(res);
- }
- }
- while (got < bufferSize);
-
- return got;
- }
-
- bool hasBufferedData() const override
- {
- if (d_conn) {
- return SSL_pending(d_conn.get()) > 0;
- }
-
- return false;
- }
-
- void close() override
- {
- if (d_conn) {
- SSL_shutdown(d_conn.get());
- }
- }
-
- std::string getServerNameIndication() const override
- {
- if (d_conn) {
- const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
- if (value) {
- return std::string(value);
- }
- }
- return std::string();
- }
-
- LibsslTLSVersion getTLSVersion() const override
- {
- auto proto = SSL_version(d_conn.get());
- switch (proto) {
- case TLS1_VERSION:
- return LibsslTLSVersion::TLS10;
- case TLS1_1_VERSION:
- return LibsslTLSVersion::TLS11;
- case TLS1_2_VERSION:
- return LibsslTLSVersion::TLS12;
-#ifdef TLS1_3_VERSION
- case TLS1_3_VERSION:
- return LibsslTLSVersion::TLS13;
-#endif /* TLS1_3_VERSION */
- default:
- return LibsslTLSVersion::Unknown;
- }
- }
-
- bool hasSessionBeenResumed() const override
- {
- if (d_conn) {
- return SSL_session_reused(d_conn.get()) != 0;
- }
- return false;
- }
-
- static int s_tlsConnIndex;
-
-private:
- static std::atomic_flag s_initTLSConnIndex;
-
- std::shared_ptr<OpenSSLFrontendContext> d_feContext;
- std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
- unsigned int d_timeout;
-};
-
-std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT;
-int OpenSSLTLSConnection::s_tlsConnIndex = -1;
-
-class OpenSSLTLSIOCtx: public TLSCtx
-{
-public:
- OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
- {
- d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
-
- if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
- /* use our own ticket keys handler so we can rotate them */
- SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
- libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
- }
-
- if (!d_feContext->d_ocspResponses.empty()) {
- SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
- SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
- }
-
- libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
-
- if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
- d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
- }
-
- try {
- if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
- handleTicketsKeyRotation(time(nullptr));
- }
- else {
- OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
- }
- }
- catch (const std::exception& e) {
- throw;
- }
- }
-
- ~OpenSSLTLSIOCtx() override
- {
- }
-
- static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
- {
- OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
- if (ctx == nullptr) {
- return -1;
- }
-
- int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
- if (enc == 0) {
- if (ret == 0 || ret == 2) {
- OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex));
- if (conn) {
- if (ret == 0) {
- conn->setUnknownTicketKey();
- }
- else if (ret == 2) {
- conn->setResumedFromInactiveTicketKey();
- }
- }
- }
- }
-
- return ret;
- }
-
- static int ocspStaplingCb(SSL* ssl, void* arg)
- {
- if (ssl == nullptr || arg == nullptr) {
- return SSL_TLSEXT_ERR_NOACK;
- }
- const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
- return libssl_ocsp_stapling_callback(ssl, *ocspMap);
- }
-
- std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
- {
- handleTicketsKeyRotation(now);
-
- return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_feContext));
- }
-
- void rotateTicketsKey(time_t now) override
- {
- d_feContext->d_ticketKeys.rotateTicketsKey(now);
-
- if (d_ticketsKeyRotationDelay > 0) {
- d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
- }
- }
-
- void loadTicketsKeys(const std::string& keyFile) override final
- {
- d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
-
- if (d_ticketsKeyRotationDelay > 0) {
- d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
- }
- }
-
- size_t getTicketsKeysCount() override
- {
- return d_feContext->d_ticketKeys.getKeysCount();
- }
-
-private:
- std::shared_ptr<OpenSSLFrontendContext> d_feContext;
-};
-
-#endif /* HAVE_LIBSSL */
-
-#ifdef HAVE_GNUTLS
-#include <gnutls/gnutls.h>
-#include <gnutls/x509.h>
-
-static void safe_memory_lock(void* data, size_t size)
-{
-#ifdef HAVE_LIBSODIUM
- sodium_mlock(data, size);
-#endif
-}
-
-static void safe_memory_release(void* data, size_t size)
-{
-#ifdef HAVE_LIBSODIUM
- sodium_munlock(data, size);
-#elif defined(HAVE_EXPLICIT_BZERO)
- explicit_bzero(data, size);
-#elif defined(HAVE_EXPLICIT_MEMSET)
- explicit_memset(data, 0, size);
-#elif defined(HAVE_GNUTLS_MEMSET)
- gnutls_memset(data, 0, size);
-#else
- /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
- volatile unsigned int volatile_zero_idx = 0;
- volatile unsigned char *p = reinterpret_cast<volatile unsigned char *>(data);
-
- if (size == 0)
- return;
-
- do {
- memset(data, 0, size);
- } while (p[volatile_zero_idx] != 0);
-#endif
-}
-
-class GnuTLSTicketsKey
-{
-public:
- GnuTLSTicketsKey()
- {
- if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error generating tickets key for TLS context");
- }
-
- safe_memory_lock(d_key.data, d_key.size);
- }
-
- GnuTLSTicketsKey(const std::string& keyFile)
- {
- /* to be sure we are loading the correct amount of data, which
- may change between versions, let's generate a correct key first */
- if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
- }
-
- safe_memory_lock(d_key.data, d_key.size);
-
- try {
- ifstream file(keyFile);
- file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
-
- if (file.fail()) {
- file.close();
- throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
- }
-
- file.close();
- }
- catch (const std::exception& e) {
- safe_memory_release(d_key.data, d_key.size);
- gnutls_free(d_key.data);
- d_key.data = nullptr;
- throw;
- }
- }
-
- ~GnuTLSTicketsKey()
- {
- if (d_key.data != nullptr && d_key.size > 0) {
- safe_memory_release(d_key.data, d_key.size);
- }
- gnutls_free(d_key.data);
- d_key.data = nullptr;
- }
- const gnutls_datum_t& getKey() const
- {
- return d_key;
- }
-
-private:
- gnutls_datum_t d_key{nullptr, 0};
-};
-
-class GnuTLSConnection: public TLSConnection
-{
-public:
-
- GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
- {
- unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
-#ifdef GNUTLS_NO_SIGNAL
- sslOptions |= GNUTLS_NO_SIGNAL;
-#endif
-
- d_socket = socket;
-
- gnutls_session_t conn;
- if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error creating TLS connection");
- }
-
- d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
- conn = nullptr;
-
- if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error setting certificate and key to TLS connection");
- }
-
- if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error setting ciphers to TLS connection");
- }
-
- if (enableTickets && d_ticketsKey) {
- const gnutls_datum_t& key = d_ticketsKey->getKey();
- if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error setting the tickets key to TLS connection");
- }
- }
-
- gnutls_transport_set_int(d_conn.get(), d_socket);
-
- /* timeouts are in milliseconds */
- gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
- gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
- }
-
- void doHandshake() override
- {
- int ret = 0;
- do {
- ret = gnutls_handshake(d_conn.get());
- if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
- throw std::runtime_error("Error accepting a new connection");
- }
- }
- while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
- }
-
- IOState tryHandshake() override
- {
- int ret = 0;
-
- do {
- ret = gnutls_handshake(d_conn.get());
- if (ret == GNUTLS_E_SUCCESS) {
- return IOState::Done;
- }
- else if (ret == GNUTLS_E_AGAIN) {
- return IOState::NeedRead;
- }
- else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
- throw std::runtime_error("Error accepting a new connection");
- }
- } while (ret == GNUTLS_E_INTERRUPTED);
-
- throw std::runtime_error("Error accepting a new connection");
- }
-
- IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override
- {
- do {
- ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
- if (res == 0) {
- throw std::runtime_error("Error writing to TLS connection");
- }
- else if (res > 0) {
- pos += static_cast<size_t>(res);
- }
- else if (res < 0) {
- if (gnutls_error_is_fatal(res)) {
- throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
- }
- else if (res == GNUTLS_E_AGAIN) {
- return IOState::NeedWrite;
- }
- warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
- }
- }
- while (pos < toWrite);
- return IOState::Done;
- }
-
- IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
- {
- do {
- ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
- if (res == 0) {
- throw std::runtime_error("Error reading from TLS connection");
- }
- else if (res > 0) {
- pos += static_cast<size_t>(res);
- }
- else if (res < 0) {
- if (gnutls_error_is_fatal(res)) {
- throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
- }
- else if (res == GNUTLS_E_AGAIN) {
- return IOState::NeedRead;
- }
- warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
- }
- }
- while (pos < toRead);
- return IOState::Done;
- }
-
- size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
- {
- size_t got = 0;
- time_t start = 0;
- unsigned int remainingTime = totalTimeout;
- if (totalTimeout) {
- start = time(nullptr);
- }
-
- do {
- ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
- if (res == 0) {
- throw std::runtime_error("Error reading from TLS connection");
- }
- else if (res > 0) {
- got += static_cast<size_t>(res);
- }
- else if (res < 0) {
- if (gnutls_error_is_fatal(res)) {
- throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
- }
- else if (res == GNUTLS_E_AGAIN) {
- int result = waitForData(d_socket, readTimeout);
- if (result <= 0) {
- throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
- }
- }
- else {
- vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
- }
- }
-
- if (totalTimeout) {
- time_t now = time(nullptr);
- unsigned int elapsed = now - start;
- if (now < start || elapsed >= remainingTime) {
- throw runtime_error("Timeout while reading data");
- }
- start = now;
- remainingTime -= elapsed;
- }
- }
- while (got < bufferSize);
-
- return got;
- }
-
- size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
- {
- size_t got = 0;
-
- do {
- ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
- if (res == 0) {
- throw std::runtime_error("Error writing to TLS connection");
- }
- else if (res > 0) {
- got += static_cast<size_t>(res);
- }
- else if (res < 0) {
- if (gnutls_error_is_fatal(res)) {
- throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
- }
- else if (res == GNUTLS_E_AGAIN) {
- int result = waitForRWData(d_socket, false, writeTimeout, 0);
- if (result <= 0) {
- throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
- }
- }
- else {
- vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
- }
- }
- }
- while (got < bufferSize);
-
- return got;
- }
-
- bool hasBufferedData() const override
- {
- if (d_conn) {
- return gnutls_record_check_pending(d_conn.get()) > 0;
- }
-
- return false;
- }
-
- std::string getServerNameIndication() const override
- {
- if (d_conn) {
- unsigned int type;
- size_t name_len = 256;
- std::string sni;
- sni.resize(name_len);
-
- int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
- if (res == GNUTLS_E_SUCCESS) {
- sni.resize(name_len);
- return sni;
- }
- }
- return std::string();
- }
-
- LibsslTLSVersion getTLSVersion() const override
- {
- auto proto = gnutls_protocol_get_version(d_conn.get());
- switch (proto) {
- case GNUTLS_TLS1_0:
- return LibsslTLSVersion::TLS10;
- case GNUTLS_TLS1_1:
- return LibsslTLSVersion::TLS11;
- case GNUTLS_TLS1_2:
- return LibsslTLSVersion::TLS12;
-#if GNUTLS_VERSION_NUMBER >= 0x030603
- case GNUTLS_TLS1_3:
- return LibsslTLSVersion::TLS13;
-#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
- default:
- return LibsslTLSVersion::Unknown;
- }
- }
-
- bool hasSessionBeenResumed() const override
- {
- if (d_conn) {
- return gnutls_session_is_resumed(d_conn.get()) != 0;
- }
- return false;
- }
-
- void close() override
- {
- if (d_conn) {
- gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
- }
- }
-
-private:
- std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
- std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
-};
-
-class GnuTLSIOCtx: public TLSCtx
-{
-public:
- GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
- {
- int rc = 0;
- d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
-
- gnutls_certificate_credentials_t creds;
- rc = gnutls_certificate_allocate_credentials(&creds);
- if (rc != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
- }
-
- d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
- creds = nullptr;
-
- for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
- rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
- if (rc != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
- }
- }
-
- size_t count = 0;
- for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
- rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
- if (rc != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
- }
- ++count;
- }
-
-#if GNUTLS_VERSION_NUMBER >= 0x030600
- rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
- if (rc != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
- }
-#endif
-
- rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
- if (rc != GNUTLS_E_SUCCESS) {
- throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
- }
-
- try {
- if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
- handleTicketsKeyRotation(time(nullptr));
- }
- else {
- GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
- }
- }
- catch(const std::runtime_error& e) {
- throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
- }
- }
-
- virtual ~GnuTLSIOCtx() override
- {
- d_creds.reset();
-
- if (d_priorityCache) {
- gnutls_priority_deinit(d_priorityCache);
- }
- }
-
- std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
- {
- handleTicketsKeyRotation(now);
-
- std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
- {
- ReadLock rl(&d_lock);
- ticketsKey = d_ticketsKey;
- }
-
- return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
- }
-
- void rotateTicketsKey(time_t now) override
- {
- if (!d_enableTickets) {
- return;
- }
-
- auto newKey = std::make_shared<GnuTLSTicketsKey>();
-
- {
- WriteLock wl(&d_lock);
- d_ticketsKey = newKey;
- }
-
- if (d_ticketsKeyRotationDelay > 0) {
- d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
- }
- }
-
- void loadTicketsKeys(const std::string& file) override final
- {
- if (!d_enableTickets) {
- return;
- }
-
- auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
- {
- WriteLock wl(&d_lock);
- d_ticketsKey = newKey;
- }
-
- if (d_ticketsKeyRotationDelay > 0) {
- d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
- }
- }
-
- size_t getTicketsKeysCount() override
- {
- ReadLock rl(&d_lock);
- return d_ticketsKey != nullptr ? 1 : 0;
- }
-
-private:
- std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
- gnutls_priority_t d_priorityCache{nullptr};
- std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
- ReadWriteLock d_lock;
- bool d_enableTickets{true};
-};
-
-#endif /* HAVE_GNUTLS */
-
-#endif /* HAVE_DNS_OVER_TLS */
-
-bool TLSFrontend::setupTLS()
-{
-#ifdef HAVE_DNS_OVER_TLS
- /* get the "best" available provider */
- if (!d_provider.empty()) {
-#ifdef HAVE_GNUTLS
- if (d_provider == "gnutls") {
- d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
- return true;
- }
-#endif /* HAVE_GNUTLS */
-#ifdef HAVE_LIBSSL
- if (d_provider == "openssl") {
- d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
- return true;
- }
-#endif /* HAVE_LIBSSL */
- }
-#ifdef HAVE_LIBSSL
- d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
-#else /* HAVE_LIBSSL */
-#ifdef HAVE_GNUTLS
- d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
-#endif /* HAVE_GNUTLS */
-#endif /* HAVE_LIBSSL */
-
-#endif /* HAVE_DNS_OVER_TLS */
- return true;
-}
--- /dev/null
+../tcpiohandler.cc
\ No newline at end of file
--- /dev/null
+
+#include "config.h"
+#include "libssl.hh"
+
+#ifdef HAVE_LIBSSL
+
+#include <atomic>
+#include <fstream>
+#include <cstring>
+#include <mutex>
+#include <pthread.h>
+
+#include <openssl/conf.h>
+#include <openssl/err.h>
+#include <openssl/ocsp.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+
+#ifdef HAVE_LIBSODIUM
+#include <sodium.h>
+#endif /* HAVE_LIBSODIUM */
+
+#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL)
+/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
+
+#include "lock.hh"
+static std::vector<std::mutex> openssllocks;
+
+extern "C" {
+static void openssl_pthreads_locking_callback(int mode, int type, const char *file, int line)
+{
+ if (mode & CRYPTO_LOCK) {
+ openssllocks.at(type).lock();
+
+ } else {
+ openssllocks.at(type).unlock();
+ }
+}
+
+static unsigned long openssl_pthreads_id_callback()
+{
+ return (unsigned long)pthread_self();
+}
+}
+
+static void openssl_thread_setup()
+{
+ openssllocks = std::vector<std::mutex>(CRYPTO_num_locks());
+ CRYPTO_set_id_callback(&openssl_pthreads_id_callback);
+ CRYPTO_set_locking_callback(&openssl_pthreads_locking_callback);
+}
+
+static void openssl_thread_cleanup()
+{
+ CRYPTO_set_locking_callback(nullptr);
+ openssllocks.clear();
+}
+
+#endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) */
+
+static std::atomic<uint64_t> s_users;
+static int s_ticketsKeyIndex{-1};
+static int s_countersIndex{-1};
+static int s_keyLogIndex{-1};
+
+void registerOpenSSLUser()
+{
+ if (s_users.fetch_add(1) == 0) {
+#ifdef HAVE_OPENSSL_INIT_CRYPTO
+ /* load the default configuration file (or one specified via OPENSSL_CONF),
+ which can then be used to load engines */
+ OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CONFIG, nullptr);
+#endif
+
+#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL))
+ SSL_load_error_strings();
+ OpenSSL_add_ssl_algorithms();
+ openssl_thread_setup();
+#endif
+ s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+
+ if (s_ticketsKeyIndex == -1) {
+ throw std::runtime_error("Error getting an index for tickets key");
+ }
+
+ s_countersIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+
+ if (s_countersIndex == -1) {
+ throw std::runtime_error("Error getting an index for counters");
+ }
+
+ s_keyLogIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+
+ if (s_keyLogIndex == -1) {
+ throw std::runtime_error("Error getting an index for TLS key logging");
+ }
+ }
+}
+
+void unregisterOpenSSLUser()
+{
+ if (s_users.fetch_sub(1) == 1) {
+#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER && LIBRESSL_VERSION_NUMBER < 0x2090100fL))
+ ERR_free_strings();
+
+ EVP_cleanup();
+
+ CONF_modules_finish();
+ CONF_modules_free();
+ CONF_modules_unload(1);
+
+ CRYPTO_cleanup_all_ex_data();
+ openssl_thread_cleanup();
+#endif
+ }
+}
+
+void* libssl_get_ticket_key_callback_data(SSL* s)
+{
+ SSL_CTX* sslCtx = SSL_get_SSL_CTX(s);
+ if (sslCtx == nullptr) {
+ return nullptr;
+ }
+
+ return SSL_CTX_get_ex_data(sslCtx, s_ticketsKeyIndex);
+}
+
+void libssl_set_ticket_key_callback_data(SSL_CTX* ctx, void* data)
+{
+ SSL_CTX_set_ex_data(ctx, s_ticketsKeyIndex, data);
+}
+
+int libssl_ticket_key_callback(SSL *s, OpenSSLTLSTicketKeysRing& keyring, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
+{
+ if (enc) {
+ const auto key = keyring.getEncryptionKey();
+ if (key == nullptr) {
+ return -1;
+ }
+
+ return key->encrypt(keyName, iv, ectx, hctx);
+ }
+
+ bool activeEncryptionKey = false;
+
+ const auto key = keyring.getDecryptionKey(keyName, activeEncryptionKey);
+ if (key == nullptr) {
+ /* we don't know this key, just create a new ticket */
+ return 0;
+ }
+
+ if (key->decrypt(iv, ectx, hctx) == false) {
+ return -1;
+ }
+
+ if (!activeEncryptionKey) {
+ /* this key is not active, please encrypt the ticket content with the currently active one */
+ return 2;
+ }
+
+ return 1;
+}
+
+static long libssl_server_name_callback(SSL* ssl, int* al, void* arg)
+{
+ (void) al;
+ (void) arg;
+
+ if (SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name)) {
+ return SSL_TLSEXT_ERR_OK;
+ }
+
+ return SSL_TLSEXT_ERR_NOACK;
+}
+
+static void libssl_info_callback(const SSL *ssl, int where, int ret)
+{
+ SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl);
+ if (sslCtx == nullptr) {
+ return;
+ }
+
+ TLSErrorCounters* counters = reinterpret_cast<TLSErrorCounters*>(SSL_CTX_get_ex_data(sslCtx, s_countersIndex));
+ if (counters == nullptr) {
+ return;
+ }
+
+ if (where & SSL_CB_ALERT) {
+ const long lastError = ERR_peek_last_error();
+ switch (ERR_GET_REASON(lastError)) {
+#ifdef SSL_R_DH_KEY_TOO_SMALL
+ case SSL_R_DH_KEY_TOO_SMALL:
+ ++counters->d_dhKeyTooSmall;
+ break;
+#endif /* SSL_R_DH_KEY_TOO_SMALL */
+ case SSL_R_NO_SHARED_CIPHER:
+ ++counters->d_noSharedCipher;
+ break;
+ case SSL_R_UNKNOWN_PROTOCOL:
+ ++counters->d_unknownProtocol;
+ break;
+ case SSL_R_UNSUPPORTED_PROTOCOL:
+#ifdef SSL_R_VERSION_TOO_LOW
+ case SSL_R_VERSION_TOO_LOW:
+#endif /* SSL_R_VERSION_TOO_LOW */
+ ++counters->d_unsupportedProtocol;
+ break;
+ case SSL_R_INAPPROPRIATE_FALLBACK:
+ ++counters->d_inappropriateFallBack;
+ break;
+ case SSL_R_UNKNOWN_CIPHER_TYPE:
+ ++counters->d_unknownCipherType;
+ break;
+ case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE:
+ ++counters->d_unknownKeyExchangeType;
+ break;
+ case SSL_R_UNSUPPORTED_ELLIPTIC_CURVE:
+ ++counters->d_unsupportedEC;
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+void libssl_set_error_counters_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, TLSErrorCounters* counters)
+{
+ SSL_CTX_set_ex_data(ctx.get(), s_countersIndex, counters);
+ SSL_CTX_set_info_callback(ctx.get(), libssl_info_callback);
+}
+
+int libssl_ocsp_stapling_callback(SSL* ssl, const std::map<int, std::string>& ocspMap)
+{
+ auto pkey = SSL_get_privatekey(ssl);
+ if (pkey == nullptr) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+
+ /* look for an OCSP response for the corresponding private key type (RSA, ECDSA..) */
+ const auto& data = ocspMap.find(EVP_PKEY_base_id(pkey));
+ if (data == ocspMap.end()) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+
+ /* we need to allocate a copy because OpenSSL will free the pointer passed to SSL_set_tlsext_status_ocsp_resp() */
+ void* copy = OPENSSL_malloc(data->second.size());
+ if (copy == nullptr) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+
+ memcpy(copy, data->second.data(), data->second.size());
+ SSL_set_tlsext_status_ocsp_resp(ssl, copy, data->second.size());
+ return SSL_TLSEXT_ERR_OK;
+}
+
+static bool libssl_validate_ocsp_response(const std::string& response)
+{
+ auto responsePtr = reinterpret_cast<const unsigned char *>(response.data());
+ std::unique_ptr<OCSP_RESPONSE, void(*)(OCSP_RESPONSE*)> resp(d2i_OCSP_RESPONSE(nullptr, &responsePtr, response.size()), OCSP_RESPONSE_free);
+ if (resp == nullptr) {
+ throw std::runtime_error("Unable to parse OCSP response");
+ }
+
+ int status = OCSP_response_status(resp.get());
+ if (status != OCSP_RESPONSE_STATUS_SUCCESSFUL) {
+ throw std::runtime_error("OCSP response status is not successful: " + std::to_string(status));
+ }
+
+ std::unique_ptr<OCSP_BASICRESP, void(*)(OCSP_BASICRESP*)> basic(OCSP_response_get1_basic(resp.get()), OCSP_BASICRESP_free);
+ if (basic == nullptr) {
+ throw std::runtime_error("Error getting a basic OCSP response");
+ }
+
+ if (OCSP_resp_count(basic.get()) != 1) {
+ throw std::runtime_error("More than one single response in an OCSP basic response");
+ }
+
+ auto singleResponse = OCSP_resp_get0(basic.get(), 0);
+ if (singleResponse == nullptr) {
+ throw std::runtime_error("Error getting a single response from the basic OCSP response");
+ }
+
+ int reason;
+ ASN1_GENERALIZEDTIME* revTime = nullptr;
+ ASN1_GENERALIZEDTIME* thisUpdate = nullptr;
+ ASN1_GENERALIZEDTIME* nextUpdate = nullptr;
+
+ auto singleResponseStatus = OCSP_single_get0_status(singleResponse, &reason, &revTime, &thisUpdate, &nextUpdate);
+ if (singleResponseStatus != V_OCSP_CERTSTATUS_GOOD) {
+ throw std::runtime_error("Invalid status for OCSP single response (" + std::to_string(singleResponseStatus) + ")");
+ }
+ if (thisUpdate == nullptr || nextUpdate == nullptr) {
+ throw std::runtime_error("Error getting validity of OCSP single response");
+ }
+
+ auto validityResult = OCSP_check_validity(thisUpdate, nextUpdate, /* 5 minutes of leeway */ 5 * 60, -1);
+ if (validityResult == 0) {
+ throw std::runtime_error("OCSP single response is not yet, or no longer, valid");
+ }
+
+ return true;
+}
+
+std::map<int, std::string> libssl_load_ocsp_responses(const std::vector<std::string>& ocspFiles, std::vector<int> keyTypes)
+{
+ std::map<int, std::string> ocspResponses;
+
+ if (ocspFiles.size() > keyTypes.size()) {
+ throw std::runtime_error("More OCSP files than certificates and keys loaded!");
+ }
+
+ size_t count = 0;
+ for (const auto& filename : ocspFiles) {
+ std::ifstream file(filename, std::ios::binary);
+ std::string content;
+ while(file) {
+ char buffer[4096];
+ file.read(buffer, sizeof(buffer));
+ if (file.bad()) {
+ file.close();
+ throw std::runtime_error("Unable to load OCSP response from '" + filename + "'");
+ }
+ content.append(buffer, file.gcount());
+ }
+ file.close();
+
+ try {
+ libssl_validate_ocsp_response(content);
+ ocspResponses.insert({keyTypes.at(count), std::move(content)});
+ }
+ catch (const std::exception& e) {
+ throw std::runtime_error("Error checking the validity of OCSP response from '" + filename + "': " + e.what());
+ }
+ ++count;
+ }
+
+ return ocspResponses;
+}
+
+int libssl_get_last_key_type(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx)
+{
+#ifdef HAVE_SSL_CTX_GET0_PRIVATEKEY
+ auto pkey = SSL_CTX_get0_privatekey(ctx.get());
+#else
+ auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(ctx.get()), SSL_free);
+ if (!temp) {
+ return -1;
+ }
+ auto pkey = SSL_get_privatekey(temp.get());
+#endif
+
+ if (!pkey) {
+ return -1;
+ }
+
+ return EVP_PKEY_base_id(pkey);
+}
+
+#ifdef HAVE_OCSP_BASIC_SIGN
+bool libssl_generate_ocsp_response(const std::string& certFile, const std::string& caCert, const std::string& caKey, const std::string& outFile, int ndays, int nmin)
+{
+ const EVP_MD* rmd = EVP_sha256();
+
+ auto fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(certFile.c_str(), "r"), fclose);
+ if (!fp) {
+ throw std::runtime_error("Unable to open '" + certFile + "' when loading the certificate to generate an OCSP response");
+ }
+ auto cert = std::unique_ptr<X509, void(*)(X509*)>(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free);
+
+ fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(caCert.c_str(), "r"), fclose);
+ if (!fp) {
+ throw std::runtime_error("Unable to open '" + caCert + "' when loading the issuer certificate to generate an OCSP response");
+ }
+ auto issuer = std::unique_ptr<X509, void(*)(X509*)>(PEM_read_X509_AUX(fp.get(), nullptr, nullptr, nullptr), X509_free);
+ fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(caKey.c_str(), "r"), fclose);
+ if (!fp) {
+ throw std::runtime_error("Unable to open '" + caKey + "' when loading the issuer key to generate an OCSP response");
+ }
+ auto issuerKey = std::unique_ptr<EVP_PKEY, void(*)(EVP_PKEY*)>(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr), EVP_PKEY_free);
+ fp.reset();
+
+ auto bs = std::unique_ptr<OCSP_BASICRESP, void(*)(OCSP_BASICRESP*)>(OCSP_BASICRESP_new(), OCSP_BASICRESP_free);
+ auto thisupd = std::unique_ptr<ASN1_TIME, void(*)(ASN1_TIME*)>(X509_gmtime_adj(nullptr, 0), ASN1_TIME_free);
+ auto nextupd = std::unique_ptr<ASN1_TIME, void(*)(ASN1_TIME*)>(X509_time_adj_ex(nullptr, ndays, nmin * 60, nullptr), ASN1_TIME_free);
+
+ auto cid = std::unique_ptr<OCSP_CERTID, void(*)(OCSP_CERTID*)>(OCSP_cert_to_id(rmd, cert.get(), issuer.get()), OCSP_CERTID_free);
+ OCSP_basic_add1_status(bs.get(), cid.get(), V_OCSP_CERTSTATUS_GOOD, 0, nullptr, thisupd.get(), nextupd.get());
+
+ if (OCSP_basic_sign(bs.get(), issuer.get(), issuerKey.get(), rmd, nullptr, OCSP_NOCERTS) != 1) {
+ throw std::runtime_error("Error while signing the OCSP response");
+ }
+
+ auto resp = std::unique_ptr<OCSP_RESPONSE, void(*)(OCSP_RESPONSE*)>(OCSP_response_create(OCSP_RESPONSE_STATUS_SUCCESSFUL, bs.get()), OCSP_RESPONSE_free);
+ auto bio = std::unique_ptr<BIO, void(*)(BIO*)>(BIO_new_file(outFile.c_str(), "wb"), BIO_vfree);
+ if (!bio) {
+ throw std::runtime_error("Error opening file for writing the OCSP response");
+ }
+
+ // i2d_OCSP_RESPONSE_bio(bio.get(), resp.get()) is unusable from C++ because of an invalid cast
+ ASN1_i2d_bio((i2d_of_void*)i2d_OCSP_RESPONSE, bio.get(), (unsigned char*)resp.get());
+
+ return true;
+}
+#endif /* HAVE_OCSP_BASIC_SIGN */
+
+LibsslTLSVersion libssl_tls_version_from_string(const std::string& str)
+{
+ if (str == "tls1.0") {
+ return LibsslTLSVersion::TLS10;
+ }
+ if (str == "tls1.1") {
+ return LibsslTLSVersion::TLS11;
+ }
+ if (str == "tls1.2") {
+ return LibsslTLSVersion::TLS12;
+ }
+ if (str == "tls1.3") {
+ return LibsslTLSVersion::TLS13;
+ }
+ throw std::runtime_error("Unknown TLS version '" + str);
+}
+
+const std::string& libssl_tls_version_to_string(LibsslTLSVersion version)
+{
+ static const std::map<LibsslTLSVersion, std::string> versions = {
+ { LibsslTLSVersion::TLS10, "tls1.0" },
+ { LibsslTLSVersion::TLS11, "tls1.1" },
+ { LibsslTLSVersion::TLS12, "tls1.2" },
+ { LibsslTLSVersion::TLS13, "tls1.3" }
+ };
+
+ const auto& it = versions.find(version);
+ if (it == versions.end()) {
+ throw std::runtime_error("Unknown TLS version (" + std::to_string((int)version) + ")");
+ }
+ return it->second;
+}
+
+bool libssl_set_min_tls_version(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, LibsslTLSVersion version)
+{
+#if defined(HAVE_SSL_CTX_SET_MIN_PROTO_VERSION) || defined(SSL_CTX_set_min_proto_version)
+ /* These functions have been introduced in 1.1.0, and the use of SSL_OP_NO_* is deprecated
+ Warning: SSL_CTX_set_min_proto_version is a function-like macro in OpenSSL */
+ int vers;
+ switch(version) {
+ case LibsslTLSVersion::TLS10:
+ vers = TLS1_VERSION;
+ break;
+ case LibsslTLSVersion::TLS11:
+ vers = TLS1_1_VERSION;
+ break;
+ case LibsslTLSVersion::TLS12:
+ vers = TLS1_2_VERSION;
+ break;
+ case LibsslTLSVersion::TLS13:
+#ifdef TLS1_3_VERSION
+ vers = TLS1_3_VERSION;
+#else
+ return false;
+#endif /* TLS1_3_VERSION */
+ break;
+ default:
+ return false;
+ }
+
+ if (SSL_CTX_set_min_proto_version(ctx.get(), vers) != 1) {
+ return false;
+ }
+ return true;
+#else
+ long vers = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
+ switch(version) {
+ case LibsslTLSVersion::TLS10:
+ break;
+ case LibsslTLSVersion::TLS11:
+ vers |= SSL_OP_NO_TLSv1;
+ break;
+ case LibsslTLSVersion::TLS12:
+ vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1;
+ break;
+ case LibsslTLSVersion::TLS13:
+ vers |= SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2;
+ break;
+ default:
+ return false;
+ }
+
+ long options = SSL_CTX_get_options(ctx.get());
+ SSL_CTX_set_options(ctx.get(), options | vers);
+ return true;
+#endif
+}
+
+OpenSSLTLSTicketKeysRing::OpenSSLTLSTicketKeysRing(size_t capacity)
+{
+ d_ticketKeys.set_capacity(capacity);
+}
+
+OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing()
+{
+}
+
+void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr<OpenSSLTLSTicketKey> newKey)
+{
+ WriteLock wl(&d_lock);
+ d_ticketKeys.push_front(newKey);
+}
+
+std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getEncryptionKey()
+{
+ ReadLock rl(&d_lock);
+ return d_ticketKeys.front();
+}
+
+std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey)
+{
+ ReadLock rl(&d_lock);
+ for (auto& key : d_ticketKeys) {
+ if (key->nameMatches(name)) {
+ activeKey = (key == d_ticketKeys.front());
+ return key;
+ }
+ }
+ return nullptr;
+}
+
+size_t OpenSSLTLSTicketKeysRing::getKeysCount()
+{
+ ReadLock rl(&d_lock);
+ return d_ticketKeys.size();
+}
+
+void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile)
+{
+ bool keyLoaded = false;
+ std::ifstream file(keyFile);
+ try {
+ do {
+ auto newKey = std::make_shared<OpenSSLTLSTicketKey>(file);
+ addKey(newKey);
+ keyLoaded = true;
+ }
+ while (!file.fail());
+ }
+ catch (const std::exception& e) {
+ /* if we haven't been able to load at least one key, fail */
+ if (!keyLoaded) {
+ throw;
+ }
+ }
+
+ file.close();
+}
+
+void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t now)
+{
+ auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
+ addKey(newKey);
+}
+
+OpenSSLTLSTicketKey::OpenSSLTLSTicketKey()
+{
+ if (RAND_bytes(d_name, sizeof(d_name)) != 1) {
+ throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
+ }
+
+ if (RAND_bytes(d_cipherKey, sizeof(d_cipherKey)) != 1) {
+ throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
+ }
+
+ if (RAND_bytes(d_hmacKey, sizeof(d_hmacKey)) != 1) {
+ throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
+ }
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_name, sizeof(d_name));
+ sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+}
+
+OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(ifstream& file)
+{
+ file.read(reinterpret_cast<char*>(d_name), sizeof(d_name));
+ file.read(reinterpret_cast<char*>(d_cipherKey), sizeof(d_cipherKey));
+ file.read(reinterpret_cast<char*>(d_hmacKey), sizeof(d_hmacKey));
+
+ if (file.fail()) {
+ throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
+ }
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_name, sizeof(d_name));
+ sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+}
+
+OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey()
+{
+#ifdef HAVE_LIBSODIUM
+ sodium_munlock(d_name, sizeof(d_name));
+ sodium_munlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_munlock(d_hmacKey, sizeof(d_hmacKey));
+#else
+ OPENSSL_cleanse(d_name, sizeof(d_name));
+ OPENSSL_cleanse(d_cipherKey, sizeof(d_cipherKey));
+ OPENSSL_cleanse(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+}
+
+bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const
+{
+ return (memcmp(d_name, name, sizeof(d_name)) == 0);
+}
+
+int OpenSSLTLSTicketKey::encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
+{
+ memcpy(keyName, d_name, sizeof(d_name));
+
+ if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) != 1) {
+ return -1;
+ }
+
+ if (EVP_EncryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
+ return -1;
+ }
+
+ if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
+ return -1;
+ }
+
+ return 1;
+}
+
+bool OpenSSLTLSTicketKey::decrypt(const unsigned char* iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx) const
+{
+ if (HMAC_Init_ex(hctx, d_hmacKey, sizeof(d_hmacKey), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
+ return false;
+ }
+
+ if (EVP_DecryptInit_ex(ectx, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey, iv) != 1) {
+ return false;
+ }
+
+ return true;
+}
+
+std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> libssl_init_server_context(const TLSConfig& config,
+ std::map<int, std::string>& ocspResponses)
+{
+ auto ctx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free);
+
+ int sslOptions =
+ SSL_OP_NO_SSLv2 |
+ SSL_OP_NO_SSLv3 |
+ SSL_OP_NO_COMPRESSION |
+ SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
+ SSL_OP_SINGLE_DH_USE |
+ SSL_OP_SINGLE_ECDH_USE;
+
+ if (!config.d_enableTickets || config.d_numberOfTicketsKeys == 0) {
+ /* for TLS 1.3 this means no stateless tickets, but stateful tickets might still be issued,
+ which is something we don't want. */
+ sslOptions |= SSL_OP_NO_TICKET;
+ /* really disable all tickets */
+#ifdef HAVE_SSL_CTX_SET_NUM_TICKETS
+ SSL_CTX_set_num_tickets(ctx.get(), 0);
+#endif /* HAVE_SSL_CTX_SET_NUM_TICKETS */
+ }
+
+ if (config.d_sessionTimeout > 0) {
+ SSL_CTX_set_timeout(ctx.get(), config.d_sessionTimeout);
+ }
+
+ if (config.d_preferServerCiphers) {
+ sslOptions |= SSL_OP_CIPHER_SERVER_PREFERENCE;
+#ifdef SSL_OP_PRIORITIZE_CHACHA
+ sslOptions |= SSL_OP_PRIORITIZE_CHACHA;
+#endif /* SSL_OP_PRIORITIZE_CHACHA */
+ }
+
+ SSL_CTX_set_options(ctx.get(), sslOptions);
+ if (!libssl_set_min_tls_version(ctx, config.d_minTLSVersion)) {
+ throw std::runtime_error("Failed to set the minimum version to '" + libssl_tls_version_to_string(config.d_minTLSVersion));
+ }
+
+#ifdef SSL_CTX_set_ecdh_auto
+ SSL_CTX_set_ecdh_auto(ctx.get(), 1);
+#endif
+
+ if (config.d_maxStoredSessions == 0) {
+ /* disable stored sessions entirely */
+ SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_OFF);
+ }
+ else {
+ /* use the internal built-in cache to store sessions */
+ SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_SERVER);
+ SSL_CTX_sess_set_cache_size(ctx.get(), config.d_maxStoredSessions);
+ }
+
+ /* we need to set this callback to acknowledge the server name sent by the client,
+ otherwise it will not stored in the session and will not be accessible when the
+ session is resumed, causing SSL_get_servername to return nullptr */
+ SSL_CTX_set_tlsext_servername_callback(ctx.get(), &libssl_server_name_callback);
+
+ std::vector<int> keyTypes;
+ /* load certificate and private key */
+ for (const auto& pair : config.d_certKeyPairs) {
+ if (SSL_CTX_use_certificate_chain_file(ctx.get(), pair.first.c_str()) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("An error occurred while trying to load the TLS server certificate file: " + pair.first);
+ }
+ if (SSL_CTX_use_PrivateKey_file(ctx.get(), pair.second.c_str(), SSL_FILETYPE_PEM) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("An error occurred while trying to load the TLS server private key file: " + pair.second);
+ }
+ if (SSL_CTX_check_private_key(ctx.get()) != 1) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("The key from '" + pair.second + "' does not match the certificate from '" + pair.first + "'");
+ }
+ /* store the type of the new key, we might need it later to select the right OCSP stapling response */
+ auto keyType = libssl_get_last_key_type(ctx);
+ if (keyType < 0) {
+ throw std::runtime_error("The key from '" + pair.second + "' has an unknown type");
+ }
+ keyTypes.push_back(keyType);
+ }
+
+ if (!config.d_ocspFiles.empty()) {
+ try {
+ ocspResponses = libssl_load_ocsp_responses(config.d_ocspFiles, keyTypes);
+ }
+ catch(const std::exception& e) {
+ throw std::runtime_error("Unable to load OCSP responses: " + std::string(e.what()));
+ }
+ }
+
+ if (!config.d_ciphers.empty() && SSL_CTX_set_cipher_list(ctx.get(), config.d_ciphers.c_str()) != 1) {
+ throw std::runtime_error("The TLS ciphers could not be set: " + config.d_ciphers);
+ }
+
+#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
+ if (!config.d_ciphers13.empty() && SSL_CTX_set_ciphersuites(ctx.get(), config.d_ciphers13.c_str()) != 1) {
+ throw std::runtime_error("The TLS 1.3 ciphers could not be set: " + config.d_ciphers13);
+ }
+#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
+
+ return ctx;
+}
+
+#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
+static void libssl_key_log_file_callback(const SSL* ssl, const char* line)
+{
+ SSL_CTX* sslCtx = SSL_get_SSL_CTX(ssl);
+ if (sslCtx == nullptr) {
+ return;
+ }
+
+ auto fp = reinterpret_cast<FILE*>(SSL_CTX_get_ex_data(sslCtx, s_keyLogIndex));
+ if (fp == nullptr) {
+ return;
+ }
+
+ fprintf(fp, "%s\n", line);
+ fflush(fp);
+}
+#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */
+
+std::unique_ptr<FILE, int(*)(FILE*)> libssl_set_key_log_file(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::string& logFile)
+{
+#ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
+ auto fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(logFile.c_str(), "a"), fclose);
+ if (!fp) {
+ throw std::runtime_error("Error opening TLS log file '" + logFile + "'");
+ }
+
+ SSL_CTX_set_ex_data(ctx.get(), s_keyLogIndex, fp.get());
+ SSL_CTX_set_keylog_callback(ctx.get(), &libssl_key_log_file_callback);
+
+ return fp;
+#else
+ return std::unique_ptr<FILE, int(*)(FILE*)>(nullptr, fclose);
+#endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */
+}
+
+#endif /* HAVE_LIBSSL */
--- /dev/null
+
+#include "config.h"
+#include "dolog.hh"
+#include "iputils.hh"
+#include "lock.hh"
+#include "tcpiohandler.hh"
+
+#ifdef HAVE_LIBSODIUM
+#include <sodium.h>
+#endif /* HAVE_LIBSODIUM */
+
+#ifdef HAVE_DNS_OVER_TLS
+#ifdef HAVE_LIBSSL
+#include <openssl/conf.h>
+#include <openssl/err.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+
+#include "libssl.hh"
+
+class OpenSSLFrontendContext
+{
+public:
+ OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
+ {
+ registerOpenSSLUser();
+
+ d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
+ if (!d_tlsCtx) {
+ ERR_print_errors_fp(stderr);
+ throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
+ }
+ }
+
+ void cleanup()
+ {
+ d_tlsCtx.reset();
+
+ unregisterOpenSSLUser();
+ }
+
+ OpenSSLTLSTicketKeysRing d_ticketKeys;
+ std::map<int, std::string> d_ocspResponses;
+ std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
+ std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
+};
+
+class OpenSSLTLSConnection: public TLSConnection
+{
+public:
+ OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
+ {
+ d_socket = socket;
+
+ if (!s_initTLSConnIndex.test_and_set()) {
+ /* not initialized yet */
+ s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+ if (s_tlsConnIndex == -1) {
+ throw std::runtime_error("Error getting an index for TLS connection data");
+ }
+ }
+
+ if (!d_conn) {
+ vinfolog("Error creating TLS object");
+ if (g_verbose) {
+ ERR_print_errors_fp(stderr);
+ }
+ throw std::runtime_error("Error creating TLS object");
+ }
+
+ if (!SSL_set_fd(d_conn.get(), d_socket)) {
+ throw std::runtime_error("Error assigning socket");
+ }
+
+ SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
+ }
+
+ IOState convertIORequestToIOState(int res) const
+ {
+ int error = SSL_get_error(d_conn.get(), res);
+ if (error == SSL_ERROR_WANT_READ) {
+ return IOState::NeedRead;
+ }
+ else if (error == SSL_ERROR_WANT_WRITE) {
+ return IOState::NeedWrite;
+ }
+ else if (error == SSL_ERROR_SYSCALL) {
+ throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno)));
+ }
+ else {
+ throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
+ }
+ }
+
+ void handleIORequest(int res, unsigned int timeout)
+ {
+ auto state = convertIORequestToIOState(res);
+ if (state == IOState::NeedRead) {
+ res = waitForData(d_socket, timeout);
+ if (res == 0) {
+ throw std::runtime_error("Timeout while reading from TLS connection");
+ }
+ else if (res < 0) {
+ throw std::runtime_error("Error waiting to read from TLS connection");
+ }
+ }
+ else if (state == IOState::NeedWrite) {
+ res = waitForRWData(d_socket, false, timeout, 0);
+ if (res == 0) {
+ throw std::runtime_error("Timeout while writing to TLS connection");
+ }
+ else if (res < 0) {
+ throw std::runtime_error("Error waiting to write to TLS connection");
+ }
+ }
+ }
+
+ IOState tryHandshake() override
+ {
+ int res = SSL_accept(d_conn.get());
+ if (res == 1) {
+ return IOState::Done;
+ }
+ else if (res < 0) {
+ return convertIORequestToIOState(res);
+ }
+
+ throw std::runtime_error("Error accepting TLS connection");
+ }
+
+ void doHandshake() override
+ {
+ int res = 0;
+ do {
+ res = SSL_accept(d_conn.get());
+ if (res < 0) {
+ handleIORequest(res, d_timeout);
+ }
+ }
+ while (res < 0);
+
+ if (res != 1) {
+ throw std::runtime_error("Error accepting TLS connection");
+ }
+ }
+
+ IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+ {
+ do {
+ int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
+ if (res <= 0) {
+ return convertIORequestToIOState(res);
+ }
+ else {
+ pos += static_cast<size_t>(res);
+ }
+ }
+ while (pos < toWrite);
+ return IOState::Done;
+ }
+
+ IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
+ {
+ do {
+ int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
+ if (res <= 0) {
+ return convertIORequestToIOState(res);
+ }
+ else {
+ pos += static_cast<size_t>(res);
+ }
+ }
+ while (pos < toRead);
+ return IOState::Done;
+ }
+
+ size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
+ {
+ size_t got = 0;
+ time_t start = 0;
+ unsigned int remainingTime = totalTimeout;
+ if (totalTimeout) {
+ start = time(nullptr);
+ }
+
+ do {
+ int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
+ if (res <= 0) {
+ handleIORequest(res, readTimeout);
+ }
+ else {
+ got += static_cast<size_t>(res);
+ }
+
+ if (totalTimeout) {
+ time_t now = time(nullptr);
+ unsigned int elapsed = now - start;
+ if (now < start || elapsed >= remainingTime) {
+ throw runtime_error("Timeout while reading data");
+ }
+ start = now;
+ remainingTime -= elapsed;
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
+ {
+ size_t got = 0;
+ do {
+ int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
+ if (res <= 0) {
+ handleIORequest(res, writeTimeout);
+ }
+ else {
+ got += static_cast<size_t>(res);
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ bool hasBufferedData() const override
+ {
+ if (d_conn) {
+ return SSL_pending(d_conn.get()) > 0;
+ }
+
+ return false;
+ }
+
+ void close() override
+ {
+ if (d_conn) {
+ SSL_shutdown(d_conn.get());
+ }
+ }
+
+ std::string getServerNameIndication() const override
+ {
+ if (d_conn) {
+ const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
+ if (value) {
+ return std::string(value);
+ }
+ }
+ return std::string();
+ }
+
+ LibsslTLSVersion getTLSVersion() const override
+ {
+ auto proto = SSL_version(d_conn.get());
+ switch (proto) {
+ case TLS1_VERSION:
+ return LibsslTLSVersion::TLS10;
+ case TLS1_1_VERSION:
+ return LibsslTLSVersion::TLS11;
+ case TLS1_2_VERSION:
+ return LibsslTLSVersion::TLS12;
+#ifdef TLS1_3_VERSION
+ case TLS1_3_VERSION:
+ return LibsslTLSVersion::TLS13;
+#endif /* TLS1_3_VERSION */
+ default:
+ return LibsslTLSVersion::Unknown;
+ }
+ }
+
+ bool hasSessionBeenResumed() const override
+ {
+ if (d_conn) {
+ return SSL_session_reused(d_conn.get()) != 0;
+ }
+ return false;
+ }
+
+ static int s_tlsConnIndex;
+
+private:
+ static std::atomic_flag s_initTLSConnIndex;
+
+ std::shared_ptr<OpenSSLFrontendContext> d_feContext;
+ std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
+ unsigned int d_timeout;
+};
+
+std::atomic_flag OpenSSLTLSConnection::s_initTLSConnIndex = ATOMIC_FLAG_INIT;
+int OpenSSLTLSConnection::s_tlsConnIndex = -1;
+
+class OpenSSLTLSIOCtx: public TLSCtx
+{
+public:
+ OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
+ {
+ d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
+
+ if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
+ /* use our own ticket keys handler so we can rotate them */
+ SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
+ libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
+ }
+
+ if (!d_feContext->d_ocspResponses.empty()) {
+ SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
+ SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
+ }
+
+ libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
+
+ if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
+ d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
+ }
+
+ try {
+ if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
+ handleTicketsKeyRotation(time(nullptr));
+ }
+ else {
+ OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
+ }
+ }
+ catch (const std::exception& e) {
+ throw;
+ }
+ }
+
+ ~OpenSSLTLSIOCtx() override
+ {
+ }
+
+ static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
+ {
+ OpenSSLFrontendContext* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
+ if (ctx == nullptr) {
+ return -1;
+ }
+
+ int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
+ if (enc == 0) {
+ if (ret == 0 || ret == 2) {
+ OpenSSLTLSConnection* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex));
+ if (conn) {
+ if (ret == 0) {
+ conn->setUnknownTicketKey();
+ }
+ else if (ret == 2) {
+ conn->setResumedFromInactiveTicketKey();
+ }
+ }
+ }
+ }
+
+ return ret;
+ }
+
+ static int ocspStaplingCb(SSL* ssl, void* arg)
+ {
+ if (ssl == nullptr || arg == nullptr) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+ const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
+ return libssl_ocsp_stapling_callback(ssl, *ocspMap);
+ }
+
+ std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
+ {
+ handleTicketsKeyRotation(now);
+
+ return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_feContext));
+ }
+
+ void rotateTicketsKey(time_t now) override
+ {
+ d_feContext->d_ticketKeys.rotateTicketsKey(now);
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ void loadTicketsKeys(const std::string& keyFile) override final
+ {
+ d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ size_t getTicketsKeysCount() override
+ {
+ return d_feContext->d_ticketKeys.getKeysCount();
+ }
+
+private:
+ std::shared_ptr<OpenSSLFrontendContext> d_feContext;
+};
+
+#endif /* HAVE_LIBSSL */
+
+#ifdef HAVE_GNUTLS
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
+
+static void safe_memory_lock(void* data, size_t size)
+{
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(data, size);
+#endif
+}
+
+static void safe_memory_release(void* data, size_t size)
+{
+#ifdef HAVE_LIBSODIUM
+ sodium_munlock(data, size);
+#elif defined(HAVE_EXPLICIT_BZERO)
+ explicit_bzero(data, size);
+#elif defined(HAVE_EXPLICIT_MEMSET)
+ explicit_memset(data, 0, size);
+#elif defined(HAVE_GNUTLS_MEMSET)
+ gnutls_memset(data, 0, size);
+#else
+ /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
+ volatile unsigned int volatile_zero_idx = 0;
+ volatile unsigned char *p = reinterpret_cast<volatile unsigned char *>(data);
+
+ if (size == 0)
+ return;
+
+ do {
+ memset(data, 0, size);
+ } while (p[volatile_zero_idx] != 0);
+#endif
+}
+
+class GnuTLSTicketsKey
+{
+public:
+ GnuTLSTicketsKey()
+ {
+ if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error generating tickets key for TLS context");
+ }
+
+ safe_memory_lock(d_key.data, d_key.size);
+ }
+
+ GnuTLSTicketsKey(const std::string& keyFile)
+ {
+ /* to be sure we are loading the correct amount of data, which
+ may change between versions, let's generate a correct key first */
+ if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
+ }
+
+ safe_memory_lock(d_key.data, d_key.size);
+
+ try {
+ ifstream file(keyFile);
+ file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
+
+ if (file.fail()) {
+ file.close();
+ throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
+ }
+
+ file.close();
+ }
+ catch (const std::exception& e) {
+ safe_memory_release(d_key.data, d_key.size);
+ gnutls_free(d_key.data);
+ d_key.data = nullptr;
+ throw;
+ }
+ }
+
+ ~GnuTLSTicketsKey()
+ {
+ if (d_key.data != nullptr && d_key.size > 0) {
+ safe_memory_release(d_key.data, d_key.size);
+ }
+ gnutls_free(d_key.data);
+ d_key.data = nullptr;
+ }
+ const gnutls_datum_t& getKey() const
+ {
+ return d_key;
+ }
+
+private:
+ gnutls_datum_t d_key{nullptr, 0};
+};
+
+class GnuTLSConnection: public TLSConnection
+{
+public:
+
+ GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
+ {
+ unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
+#ifdef GNUTLS_NO_SIGNAL
+ sslOptions |= GNUTLS_NO_SIGNAL;
+#endif
+
+ d_socket = socket;
+
+ gnutls_session_t conn;
+ if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error creating TLS connection");
+ }
+
+ d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
+ conn = nullptr;
+
+ if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error setting certificate and key to TLS connection");
+ }
+
+ if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error setting ciphers to TLS connection");
+ }
+
+ if (enableTickets && d_ticketsKey) {
+ const gnutls_datum_t& key = d_ticketsKey->getKey();
+ if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error setting the tickets key to TLS connection");
+ }
+ }
+
+ gnutls_transport_set_int(d_conn.get(), d_socket);
+
+ /* timeouts are in milliseconds */
+ gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
+ gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
+ }
+
+ void doHandshake() override
+ {
+ int ret = 0;
+ do {
+ ret = gnutls_handshake(d_conn.get());
+ if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+ throw std::runtime_error("Error accepting a new connection");
+ }
+ }
+ while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
+ }
+
+ IOState tryHandshake() override
+ {
+ int ret = 0;
+
+ do {
+ ret = gnutls_handshake(d_conn.get());
+ if (ret == GNUTLS_E_SUCCESS) {
+ return IOState::Done;
+ }
+ else if (ret == GNUTLS_E_AGAIN) {
+ return IOState::NeedRead;
+ }
+ else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+ throw std::runtime_error("Error accepting a new connection");
+ }
+ } while (ret == GNUTLS_E_INTERRUPTED);
+
+ throw std::runtime_error("Error accepting a new connection");
+ }
+
+ IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+ {
+ do {
+ ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
+ if (res == 0) {
+ throw std::runtime_error("Error writing to TLS connection");
+ }
+ else if (res > 0) {
+ pos += static_cast<size_t>(res);
+ }
+ else if (res < 0) {
+ if (gnutls_error_is_fatal(res)) {
+ throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
+ }
+ else if (res == GNUTLS_E_AGAIN) {
+ return IOState::NeedWrite;
+ }
+ warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+ }
+ }
+ while (pos < toWrite);
+ return IOState::Done;
+ }
+
+ IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
+ {
+ do {
+ ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
+ if (res == 0) {
+ throw std::runtime_error("Error reading from TLS connection");
+ }
+ else if (res > 0) {
+ pos += static_cast<size_t>(res);
+ }
+ else if (res < 0) {
+ if (gnutls_error_is_fatal(res)) {
+ throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
+ }
+ else if (res == GNUTLS_E_AGAIN) {
+ return IOState::NeedRead;
+ }
+ warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+ }
+ }
+ while (pos < toRead);
+ return IOState::Done;
+ }
+
+ size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
+ {
+ size_t got = 0;
+ time_t start = 0;
+ unsigned int remainingTime = totalTimeout;
+ if (totalTimeout) {
+ start = time(nullptr);
+ }
+
+ do {
+ ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
+ if (res == 0) {
+ throw std::runtime_error("Error reading from TLS connection");
+ }
+ else if (res > 0) {
+ got += static_cast<size_t>(res);
+ }
+ else if (res < 0) {
+ if (gnutls_error_is_fatal(res)) {
+ throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
+ }
+ else if (res == GNUTLS_E_AGAIN) {
+ int result = waitForData(d_socket, readTimeout);
+ if (result <= 0) {
+ throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
+ }
+ }
+ else {
+ vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
+ }
+ }
+
+ if (totalTimeout) {
+ time_t now = time(nullptr);
+ unsigned int elapsed = now - start;
+ if (now < start || elapsed >= remainingTime) {
+ throw runtime_error("Timeout while reading data");
+ }
+ start = now;
+ remainingTime -= elapsed;
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
+ {
+ size_t got = 0;
+
+ do {
+ ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
+ if (res == 0) {
+ throw std::runtime_error("Error writing to TLS connection");
+ }
+ else if (res > 0) {
+ got += static_cast<size_t>(res);
+ }
+ else if (res < 0) {
+ if (gnutls_error_is_fatal(res)) {
+ throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
+ }
+ else if (res == GNUTLS_E_AGAIN) {
+ int result = waitForRWData(d_socket, false, writeTimeout, 0);
+ if (result <= 0) {
+ throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
+ }
+ }
+ else {
+ vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+ }
+ }
+ }
+ while (got < bufferSize);
+
+ return got;
+ }
+
+ bool hasBufferedData() const override
+ {
+ if (d_conn) {
+ return gnutls_record_check_pending(d_conn.get()) > 0;
+ }
+
+ return false;
+ }
+
+ std::string getServerNameIndication() const override
+ {
+ if (d_conn) {
+ unsigned int type;
+ size_t name_len = 256;
+ std::string sni;
+ sni.resize(name_len);
+
+ int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
+ if (res == GNUTLS_E_SUCCESS) {
+ sni.resize(name_len);
+ return sni;
+ }
+ }
+ return std::string();
+ }
+
+ LibsslTLSVersion getTLSVersion() const override
+ {
+ auto proto = gnutls_protocol_get_version(d_conn.get());
+ switch (proto) {
+ case GNUTLS_TLS1_0:
+ return LibsslTLSVersion::TLS10;
+ case GNUTLS_TLS1_1:
+ return LibsslTLSVersion::TLS11;
+ case GNUTLS_TLS1_2:
+ return LibsslTLSVersion::TLS12;
+#if GNUTLS_VERSION_NUMBER >= 0x030603
+ case GNUTLS_TLS1_3:
+ return LibsslTLSVersion::TLS13;
+#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
+ default:
+ return LibsslTLSVersion::Unknown;
+ }
+ }
+
+ bool hasSessionBeenResumed() const override
+ {
+ if (d_conn) {
+ return gnutls_session_is_resumed(d_conn.get()) != 0;
+ }
+ return false;
+ }
+
+ void close() override
+ {
+ if (d_conn) {
+ gnutls_bye(d_conn.get(), GNUTLS_SHUT_WR);
+ }
+ }
+
+private:
+ std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
+ std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
+};
+
+class GnuTLSIOCtx: public TLSCtx
+{
+public:
+ GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
+ {
+ int rc = 0;
+ d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
+
+ gnutls_certificate_credentials_t creds;
+ rc = gnutls_certificate_allocate_credentials(&creds);
+ if (rc != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
+ }
+
+ d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+ creds = nullptr;
+
+ for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
+ rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM);
+ if (rc != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
+ }
+ }
+
+ size_t count = 0;
+ for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
+ rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
+ if (rc != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
+ }
+ ++count;
+ }
+
+#if GNUTLS_VERSION_NUMBER >= 0x030600
+ rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
+ if (rc != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
+ }
+#endif
+
+ rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
+ if (rc != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
+ }
+
+ try {
+ if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
+ handleTicketsKeyRotation(time(nullptr));
+ }
+ else {
+ GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
+ }
+ }
+ catch(const std::runtime_error& e) {
+ throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
+ }
+ }
+
+ virtual ~GnuTLSIOCtx() override
+ {
+ d_creds.reset();
+
+ if (d_priorityCache) {
+ gnutls_priority_deinit(d_priorityCache);
+ }
+ }
+
+ std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
+ {
+ handleTicketsKeyRotation(now);
+
+ std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
+ {
+ ReadLock rl(&d_lock);
+ ticketsKey = d_ticketsKey;
+ }
+
+ return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
+ }
+
+ void rotateTicketsKey(time_t now) override
+ {
+ if (!d_enableTickets) {
+ return;
+ }
+
+ auto newKey = std::make_shared<GnuTLSTicketsKey>();
+
+ {
+ WriteLock wl(&d_lock);
+ d_ticketsKey = newKey;
+ }
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ void loadTicketsKeys(const std::string& file) override final
+ {
+ if (!d_enableTickets) {
+ return;
+ }
+
+ auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
+ {
+ WriteLock wl(&d_lock);
+ d_ticketsKey = newKey;
+ }
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+ }
+
+ size_t getTicketsKeysCount() override
+ {
+ ReadLock rl(&d_lock);
+ return d_ticketsKey != nullptr ? 1 : 0;
+ }
+
+private:
+ std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
+ gnutls_priority_t d_priorityCache{nullptr};
+ std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
+ ReadWriteLock d_lock;
+ bool d_enableTickets{true};
+};
+
+#endif /* HAVE_GNUTLS */
+
+#endif /* HAVE_DNS_OVER_TLS */
+
+bool TLSFrontend::setupTLS()
+{
+#ifdef HAVE_DNS_OVER_TLS
+ /* get the "best" available provider */
+ if (!d_provider.empty()) {
+#ifdef HAVE_GNUTLS
+ if (d_provider == "gnutls") {
+ d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
+ return true;
+ }
+#endif /* HAVE_GNUTLS */
+#ifdef HAVE_LIBSSL
+ if (d_provider == "openssl") {
+ d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+ return true;
+ }
+#endif /* HAVE_LIBSSL */
+ }
+#ifdef HAVE_LIBSSL
+ d_ctx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+#else /* HAVE_LIBSSL */
+#ifdef HAVE_GNUTLS
+ d_ctx = std::make_shared<GnuTLSIOCtx>(*this);
+#endif /* HAVE_GNUTLS */
+#endif /* HAVE_LIBSSL */
+
+#endif /* HAVE_DNS_OVER_TLS */
+ return true;
+}