7 #include "tcpiohandler.hh"
11 #endif /* HAVE_LIBSODIUM */
13 #ifdef HAVE_DNS_OVER_TLS
15 #include <openssl/conf.h>
16 #include <openssl/err.h>
17 #include <openssl/rand.h>
18 #include <openssl/ssl.h>
20 #include <boost/circular_buffer.hpp>
22 #if (OPENSSL_VERSION_NUMBER < 0x1010000fL || defined LIBRESSL_VERSION_NUMBER)
23 /* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
24 static pthread_mutex_t
*openssllocks
{nullptr};
27 static void openssl_pthreads_locking_callback(int mode
, int type
, const char *file
, int line
)
29 if (mode
& CRYPTO_LOCK
) {
30 pthread_mutex_lock(&(openssllocks
[type
]));
33 pthread_mutex_unlock(&(openssllocks
[type
]));
/* Thread-id callback handed to OpenSSL: returns the calling thread's
   pthread identifier converted to an unsigned long. */
static unsigned long openssl_pthreads_id_callback()
{
  const pthread_t self = pthread_self();
  return (unsigned long) self;
}
43 static void openssl_thread_setup()
45 openssllocks
= (pthread_mutex_t
*)OPENSSL_malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t
));
47 for (int i
= 0; i
< CRYPTO_num_locks(); i
++)
48 pthread_mutex_init(&(openssllocks
[i
]), NULL
);
50 CRYPTO_set_id_callback(openssl_pthreads_id_callback
);
51 CRYPTO_set_locking_callback(openssl_pthreads_locking_callback
);
54 static void openssl_thread_cleanup()
56 CRYPTO_set_locking_callback(NULL
);
58 for (int i
=0; i
<CRYPTO_num_locks(); i
++) {
59 pthread_mutex_destroy(&(openssllocks
[i
]));
62 OPENSSL_free(openssllocks
);
66 static void openssl_thread_setup()
70 static void openssl_thread_cleanup()
73 #endif /* (OPENSSL_VERSION_NUMBER < 0x1010000fL || defined LIBRESSL_VERSION_NUMBER) */
75 /* From rfc5077 Section 4. Recommended Ticket Construction */
76 #define TLS_TICKETS_KEY_NAME_SIZE (16)
79 #define TLS_TICKETS_CIPHER_KEY_SIZE (32)
80 #define TLS_TICKETS_CIPHER_ALGO (EVP_aes_256_cbc)
83 #define TLS_TICKETS_MAC_KEY_SIZE (32)
84 #define TLS_TICKETS_MAC_ALGO (EVP_sha256)
86 static int s_ticketsKeyIndex
{-1};
88 class OpenSSLTLSTicketKey
93 if (RAND_bytes(d_name
, sizeof(d_name
)) != 1) {
94 throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
97 if (RAND_bytes(d_cipherKey
, sizeof(d_cipherKey
)) != 1) {
98 throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
101 if (RAND_bytes(d_hmacKey
, sizeof(d_hmacKey
)) != 1) {
102 throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
104 #ifdef HAVE_LIBSODIUM
105 sodium_mlock(d_name
, sizeof(d_name
));
106 sodium_mlock(d_cipherKey
, sizeof(d_cipherKey
));
107 sodium_mlock(d_hmacKey
, sizeof(d_hmacKey
));
108 #endif /* HAVE_LIBSODIUM */
111 OpenSSLTLSTicketKey(ifstream
& file
)
113 file
.read(reinterpret_cast<char*>(d_name
), sizeof(d_name
));
114 file
.read(reinterpret_cast<char*>(d_cipherKey
), sizeof(d_cipherKey
));
115 file
.read(reinterpret_cast<char*>(d_hmacKey
), sizeof(d_hmacKey
));
118 throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
120 #ifdef HAVE_LIBSODIUM
121 sodium_mlock(d_name
, sizeof(d_name
));
122 sodium_mlock(d_cipherKey
, sizeof(d_cipherKey
));
123 sodium_mlock(d_hmacKey
, sizeof(d_hmacKey
));
124 #endif /* HAVE_LIBSODIUM */
127 ~OpenSSLTLSTicketKey()
129 #ifdef HAVE_LIBSODIUM
130 sodium_munlock(d_name
, sizeof(d_name
));
131 sodium_munlock(d_cipherKey
, sizeof(d_cipherKey
));
132 sodium_munlock(d_hmacKey
, sizeof(d_hmacKey
));
134 OPENSSL_cleanse(d_name
, sizeof(d_name
));
135 OPENSSL_cleanse(d_cipherKey
, sizeof(d_cipherKey
));
136 OPENSSL_cleanse(d_hmacKey
, sizeof(d_hmacKey
));
137 #endif /* HAVE_LIBSODIUM */
140 bool nameMatches(const unsigned char name
[TLS_TICKETS_KEY_NAME_SIZE
]) const
142 return (memcmp(d_name
, name
, sizeof(d_name
)) == 0);
145 int encrypt(unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char *iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
) const
147 memcpy(keyName
, d_name
, sizeof(d_name
));
149 if (RAND_bytes(iv
, EVP_MAX_IV_LENGTH
) != 1) {
153 if (EVP_EncryptInit_ex(ectx
, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey
, iv
) != 1) {
157 if (HMAC_Init_ex(hctx
, d_hmacKey
, sizeof(d_hmacKey
), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
164 bool decrypt(const unsigned char* iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
) const
166 if (HMAC_Init_ex(hctx
, d_hmacKey
, sizeof(d_hmacKey
), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
170 if (EVP_DecryptInit_ex(ectx
, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey
, iv
) != 1) {
178 unsigned char d_name
[TLS_TICKETS_KEY_NAME_SIZE
];
179 unsigned char d_cipherKey
[TLS_TICKETS_CIPHER_KEY_SIZE
];
180 unsigned char d_hmacKey
[TLS_TICKETS_MAC_KEY_SIZE
];
183 class OpenSSLTLSTicketKeysRing
186 OpenSSLTLSTicketKeysRing(size_t capacity
)
188 pthread_rwlock_init(&d_lock
, nullptr);
189 d_ticketKeys
.set_capacity(capacity
);
192 ~OpenSSLTLSTicketKeysRing()
194 pthread_rwlock_destroy(&d_lock
);
197 void addKey(std::shared_ptr
<OpenSSLTLSTicketKey
> newKey
)
199 WriteLock
wl(&d_lock
);
200 d_ticketKeys
.push_back(newKey
);
203 std::shared_ptr
<OpenSSLTLSTicketKey
> getEncryptionKey()
205 ReadLock
rl(&d_lock
);
206 return d_ticketKeys
.front();
209 std::shared_ptr
<OpenSSLTLSTicketKey
> getDecryptionKey(unsigned char name
[TLS_TICKETS_KEY_NAME_SIZE
], bool& activeKey
)
211 ReadLock
rl(&d_lock
);
212 for (auto& key
: d_ticketKeys
) {
213 if (key
->nameMatches(name
)) {
214 activeKey
= (key
== d_ticketKeys
.front());
221 size_t getKeysCount()
223 ReadLock
rl(&d_lock
);
224 return d_ticketKeys
.size();
228 boost::circular_buffer
<std::shared_ptr
<OpenSSLTLSTicketKey
> > d_ticketKeys
;
229 pthread_rwlock_t d_lock
;
232 class OpenSSLTLSConnection
: public TLSConnection
235 OpenSSLTLSConnection(int socket
, unsigned int timeout
, SSL_CTX
* tlsCtx
): d_conn(std::unique_ptr
<SSL
, void(*)(SSL
*)>(SSL_new(tlsCtx
), SSL_free
)), d_timeout(timeout
)
240 vinfolog("Error creating TLS object");
242 ERR_print_errors_fp(stderr
);
244 throw std::runtime_error("Error creating TLS object");
247 if (!SSL_set_fd(d_conn
.get(), d_socket
)) {
248 throw std::runtime_error("Error assigning socket");
252 IOState
convertIORequestToIOState(int res
) const
254 int error
= SSL_get_error(d_conn
.get(), res
);
255 if (error
== SSL_ERROR_WANT_READ
) {
256 return IOState::NeedRead
;
258 else if (error
== SSL_ERROR_WANT_WRITE
) {
259 return IOState::NeedWrite
;
261 else if (error
== SSL_ERROR_SYSCALL
) {
262 throw std::runtime_error("Error while processing TLS connection:" + std::string(strerror(errno
)));
265 throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error
));
269 void handleIORequest(int res
, unsigned int timeout
)
271 auto state
= convertIORequestToIOState(res
);
272 if (state
== IOState::NeedRead
) {
273 res
= waitForData(d_socket
, timeout
);
275 throw std::runtime_error("Error reading from TLS connection");
278 else if (state
== IOState::NeedWrite
) {
279 res
= waitForRWData(d_socket
, false, timeout
, 0);
281 throw std::runtime_error("Error waiting to write to TLS connection");
286 IOState
tryHandshake() override
288 int res
= SSL_accept(d_conn
.get());
290 return IOState::Done
;
293 return convertIORequestToIOState(res
);
296 throw std::runtime_error("Error accepting TLS connection");
299 void doHandshake() override
303 res
= SSL_accept(d_conn
.get());
305 handleIORequest(res
, d_timeout
);
311 throw std::runtime_error("Error accepting TLS connection");
315 IOState
tryWrite(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toWrite
) override
318 int res
= SSL_write(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), static_cast<int>(toWrite
- pos
));
320 throw std::runtime_error("Error writing to TLS connection");
323 return convertIORequestToIOState(res
);
326 pos
+= static_cast<size_t>(res
);
329 while (pos
< toWrite
);
330 return IOState::Done
;
333 IOState
tryRead(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toRead
) override
336 int res
= SSL_read(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), static_cast<int>(toRead
- pos
));
338 throw std::runtime_error("Error reading from TLS connection");
341 return convertIORequestToIOState(res
);
344 pos
+= static_cast<size_t>(res
);
347 while (pos
< toRead
);
348 return IOState::Done
;
351 size_t read(void* buffer
, size_t bufferSize
, unsigned int readTimeout
, unsigned int totalTimeout
) override
355 unsigned int remainingTime
= totalTimeout
;
357 start
= time(nullptr);
361 int res
= SSL_read(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
363 throw std::runtime_error("Error reading from TLS connection");
366 handleIORequest(res
, readTimeout
);
369 got
+= static_cast<size_t>(res
);
373 time_t now
= time(nullptr);
374 unsigned int elapsed
= now
- start
;
375 if (now
< start
|| elapsed
>= remainingTime
) {
376 throw runtime_error("Timeout while reading data");
379 remainingTime
-= elapsed
;
382 while (got
< bufferSize
);
387 size_t write(const void* buffer
, size_t bufferSize
, unsigned int writeTimeout
) override
391 int res
= SSL_write(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
393 throw std::runtime_error("Error writing to TLS connection");
396 handleIORequest(res
, writeTimeout
);
399 got
+= static_cast<size_t>(res
);
402 while (got
< bufferSize
);
406 void close() override
409 SSL_shutdown(d_conn
.get());
414 std::unique_ptr
<SSL
, void(*)(SSL
*)> d_conn
;
415 unsigned int d_timeout
;
418 class OpenSSLTLSIOCtx
: public TLSCtx
421 OpenSSLTLSIOCtx(const TLSFrontend
& fe
): d_ticketKeys(fe
.d_numberOfTicketsKeys
), d_tlsCtx(std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)>(nullptr, SSL_CTX_free
))
423 d_ticketsKeyRotationDelay
= fe
.d_ticketsKeyRotationDelay
;
428 SSL_OP_NO_COMPRESSION
|
429 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
|
430 SSL_OP_SINGLE_DH_USE
|
431 SSL_OP_SINGLE_ECDH_USE
|
432 SSL_OP_CIPHER_SERVER_PREFERENCE
;
434 if (!fe
.d_enableTickets
) {
435 sslOptions
|= SSL_OP_NO_TICKET
;
438 if (s_users
.fetch_add(1) == 0) {
439 ERR_load_crypto_strings();
440 OpenSSL_add_ssl_algorithms();
441 openssl_thread_setup();
443 s_ticketsKeyIndex
= SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
445 if (s_ticketsKeyIndex
== -1) {
446 throw std::runtime_error("Error getting an index for tickets key");
450 d_tlsCtx
= std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free
);
452 ERR_print_errors_fp(stderr
);
453 throw std::runtime_error("Error creating TLS context on " + fe
.d_addr
.toStringWithPort());
456 /* use our own ticket keys handler so we can rotate them */
457 SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ticketKeyCb
);
458 SSL_CTX_set_ex_data(d_tlsCtx
.get(), s_ticketsKeyIndex
, this);
459 SSL_CTX_set_options(d_tlsCtx
.get(), sslOptions
);
460 #if defined(SSL_CTX_set_ecdh_auto)
461 SSL_CTX_set_ecdh_auto(d_tlsCtx
.get(), 1);
463 if (fe
.d_maxStoredSessions
== 0) {
464 /* disable stored sessions entirely */
465 SSL_CTX_set_session_cache_mode(d_tlsCtx
.get(), SSL_SESS_CACHE_OFF
);
468 /* use the internal built-in cache to store sessions */
469 SSL_CTX_set_session_cache_mode(d_tlsCtx
.get(), SSL_SESS_CACHE_SERVER
);
470 SSL_CTX_sess_set_cache_size(d_tlsCtx
.get(), fe
.d_maxStoredSessions
);
473 for (const auto& pair
: fe
.d_certKeyPairs
) {
474 if (SSL_CTX_use_certificate_chain_file(d_tlsCtx
.get(), pair
.first
.c_str()) != 1) {
475 ERR_print_errors_fp(stderr
);
476 throw std::runtime_error("Error loading certificate from " + pair
.first
+ " for the TLS context on " + fe
.d_addr
.toStringWithPort());
478 if (SSL_CTX_use_PrivateKey_file(d_tlsCtx
.get(), pair
.second
.c_str(), SSL_FILETYPE_PEM
) != 1) {
479 ERR_print_errors_fp(stderr
);
480 throw std::runtime_error("Error loading key from " + pair
.second
+ " for the TLS context on " + fe
.d_addr
.toStringWithPort());
484 if (!fe
.d_ciphers
.empty()) {
485 if (SSL_CTX_set_cipher_list(d_tlsCtx
.get(), fe
.d_ciphers
.c_str()) != 1) {
486 ERR_print_errors_fp(stderr
);
487 throw std::runtime_error("Error setting the cipher list to '" + fe
.d_ciphers
+ "' for the TLS context on " + fe
.d_addr
.toStringWithPort());
492 if (fe
.d_ticketKeyFile
.empty()) {
493 handleTicketsKeyRotation(time(nullptr));
496 loadTicketsKeys(fe
.d_ticketKeyFile
);
499 catch (const std::exception
& e
) {
504 virtual ~OpenSSLTLSIOCtx() override
508 if (s_users
.fetch_sub(1) == 1) {
513 CONF_modules_finish();
515 CONF_modules_unload(1);
517 CRYPTO_cleanup_all_ex_data();
518 openssl_thread_cleanup();
522 static int ticketKeyCb(SSL
*s
, unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char *iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
, int enc
)
524 SSL_CTX
* sslCtx
= SSL_get_SSL_CTX(s
);
525 if (sslCtx
== nullptr) {
529 OpenSSLTLSIOCtx
* ctx
= reinterpret_cast<OpenSSLTLSIOCtx
*>(SSL_CTX_get_ex_data(sslCtx
, s_ticketsKeyIndex
));
530 if (ctx
== nullptr) {
535 const auto key
= ctx
->d_ticketKeys
.getEncryptionKey();
536 if (key
== nullptr) {
540 return key
->encrypt(keyName
, iv
, ectx
, hctx
);
543 bool activeEncryptionKey
= false;
545 const auto key
= ctx
->d_ticketKeys
.getDecryptionKey(keyName
, activeEncryptionKey
);
546 if (key
== nullptr) {
547 /* we don't know this key, just create a new ticket */
551 if (key
->decrypt(iv
, ectx
, hctx
) == false) {
555 if (!activeEncryptionKey
) {
556 /* this key is not active, please encrypt the ticket content with the currently active one */
563 std::unique_ptr
<TLSConnection
> getConnection(int socket
, unsigned int timeout
, time_t now
) override
565 handleTicketsKeyRotation(now
);
567 return std::unique_ptr
<OpenSSLTLSConnection
>(new OpenSSLTLSConnection(socket
, timeout
, d_tlsCtx
.get()));
570 void rotateTicketsKey(time_t now
) override
572 auto newKey
= std::make_shared
<OpenSSLTLSTicketKey
>();
573 d_ticketKeys
.addKey(newKey
);
575 if (d_ticketsKeyRotationDelay
> 0) {
576 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
580 void loadTicketsKeys(const std::string
& keyFile
) override
582 bool keyLoaded
= false;
583 ifstream
file(keyFile
);
586 auto newKey
= std::make_shared
<OpenSSLTLSTicketKey
>(file
);
587 d_ticketKeys
.addKey(newKey
);
590 while (!file
.fail());
592 catch (const std::exception
& e
) {
593 /* if we haven't been able to load at least one key, fail */
599 if (d_ticketsKeyRotationDelay
> 0) {
600 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
606 size_t getTicketsKeysCount() override
608 return d_ticketKeys
.getKeysCount();
612 OpenSSLTLSTicketKeysRing d_ticketKeys
;
613 std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)> d_tlsCtx
;
614 static std::atomic
<uint64_t> s_users
;
617 std::atomic
<uint64_t> OpenSSLTLSIOCtx::s_users(0);
619 #endif /* HAVE_LIBSSL */
622 #include <gnutls/gnutls.h>
623 #include <gnutls/x509.h>
/* Passes the `size` bytes starting at `data` to sodium_mlock() when the build
   has libsodium support (HAVE_LIBSODIUM); does nothing otherwise. */
void safe_memory_lock(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_mlock(data, size);
#endif /* HAVE_LIBSODIUM */
}
632 void safe_memory_release(void* data
, size_t size
)
634 #ifdef HAVE_LIBSODIUM
635 sodium_munlock(data
, size
);
636 #elif defined(HAVE_EXPLICIT_BZERO)
637 explicit_bzero(data
, size
);
638 #elif defined(HAVE_EXPLICIT_MEMSET)
639 explicit_memset(data
, 0, size
);
640 #elif defined(HAVE_GNUTLS_MEMSET)
641 gnutls_memset(data
, 0, size
);
643 /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
644 volatile unsigned int volatile_zero_idx
= 0;
645 volatile unsigned char *p
= reinterpret_cast<volatile unsigned char *>(data
);
651 memset(data
, 0, size
);
652 } while (p
[volatile_zero_idx
] != 0);
656 class GnuTLSTicketsKey
661 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
662 throw std::runtime_error("Error generating tickets key for TLS context");
665 safe_memory_lock(d_key
.data
, d_key
.size
);
668 GnuTLSTicketsKey(const std::string
& keyFile
)
670 /* to be sure we are loading the correct amount of data, which
671 may change between versions, let's generate a correct key first */
672 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
673 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
676 safe_memory_lock(d_key
.data
, d_key
.size
);
679 ifstream
file(keyFile
);
680 file
.read(reinterpret_cast<char*>(d_key
.data
), d_key
.size
);
684 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile
);
689 catch (const std::exception
& e
) {
690 safe_memory_release(d_key
.data
, d_key
.size
);
691 gnutls_free(d_key
.data
);
692 d_key
.data
= nullptr;
699 if (d_key
.data
!= nullptr && d_key
.size
> 0) {
700 safe_memory_release(d_key
.data
, d_key
.size
);
702 gnutls_free(d_key
.data
);
703 d_key
.data
= nullptr;
705 const gnutls_datum_t
& getKey() const
711 gnutls_datum_t d_key
{nullptr, 0};
714 class GnuTLSConnection
: public TLSConnection
718 GnuTLSConnection(int socket
, unsigned int timeout
, const gnutls_certificate_credentials_t creds
, const gnutls_priority_t priorityCache
, std::shared_ptr
<GnuTLSTicketsKey
>& ticketsKey
, bool enableTickets
): d_conn(std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(nullptr, gnutls_deinit
)), d_ticketsKey(ticketsKey
)
720 unsigned int sslOptions
= GNUTLS_SERVER
| GNUTLS_NONBLOCK
;
721 #ifdef GNUTLS_NO_SIGNAL
722 sslOptions
|= GNUTLS_NO_SIGNAL
;
727 gnutls_session_t conn
;
728 if (gnutls_init(&conn
, sslOptions
) != GNUTLS_E_SUCCESS
) {
729 throw std::runtime_error("Error creating TLS connection");
732 d_conn
= std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(conn
, gnutls_deinit
);
735 if (gnutls_credentials_set(d_conn
.get(), GNUTLS_CRD_CERTIFICATE
, creds
) != GNUTLS_E_SUCCESS
) {
736 throw std::runtime_error("Error setting certificate and key to TLS connection");
739 if (gnutls_priority_set(d_conn
.get(), priorityCache
) != GNUTLS_E_SUCCESS
) {
740 throw std::runtime_error("Error setting ciphers to TLS connection");
743 if (enableTickets
&& d_ticketsKey
) {
744 const gnutls_datum_t
& key
= d_ticketsKey
->getKey();
745 if (gnutls_session_ticket_enable_server(d_conn
.get(), &key
) != GNUTLS_E_SUCCESS
) {
746 throw std::runtime_error("Error setting the tickets key to TLS connection");
750 gnutls_transport_set_int(d_conn
.get(), d_socket
);
752 /* timeouts are in milliseconds */
753 gnutls_handshake_set_timeout(d_conn
.get(), timeout
* 1000);
754 gnutls_record_set_timeout(d_conn
.get(), timeout
* 1000);
757 void doHandshake() override
761 ret
= gnutls_handshake(d_conn
.get());
762 if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
763 throw std::runtime_error("Error accepting a new connection");
766 while (ret
< 0 && ret
== GNUTLS_E_INTERRUPTED
);
769 IOState
tryHandshake() override
774 ret
= gnutls_handshake(d_conn
.get());
775 if (ret
== GNUTLS_E_SUCCESS
) {
776 return IOState::Done
;
778 else if (ret
== GNUTLS_E_AGAIN
) {
779 return IOState::NeedRead
;
781 else if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
782 throw std::runtime_error("Error accepting a new connection");
784 } while (ret
== GNUTLS_E_INTERRUPTED
);
786 throw std::runtime_error("Error accepting a new connection");
789 IOState
tryWrite(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toWrite
) override
792 ssize_t res
= gnutls_record_send(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), toWrite
- pos
);
794 throw std::runtime_error("Error writing to TLS connection");
797 pos
+= static_cast<size_t>(res
);
800 if (gnutls_error_is_fatal(res
)) {
801 throw std::runtime_error("Error writing to TLS connection");
803 else if (res
== GNUTLS_E_AGAIN
) {
804 return IOState::NeedWrite
;
806 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
809 while (pos
< toWrite
);
810 return IOState::Done
;
813 IOState
tryRead(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toRead
) override
816 ssize_t res
= gnutls_record_recv(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), toRead
- pos
);
818 throw std::runtime_error("Error reading from TLS connection");
821 pos
+= static_cast<size_t>(res
);
824 if (gnutls_error_is_fatal(res
)) {
825 throw std::runtime_error("Error reading from TLS connection");
827 else if (res
== GNUTLS_E_AGAIN
) {
828 return IOState::NeedRead
;
830 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
833 while (pos
< toRead
);
834 return IOState::Done
;
837 size_t read(void* buffer
, size_t bufferSize
, unsigned int readTimeout
, unsigned int totalTimeout
) override
841 unsigned int remainingTime
= totalTimeout
;
843 start
= time(nullptr);
847 ssize_t res
= gnutls_record_recv(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), bufferSize
- got
);
849 throw std::runtime_error("Error reading from TLS connection");
852 got
+= static_cast<size_t>(res
);
855 if (gnutls_error_is_fatal(res
)) {
856 throw std::runtime_error("Error reading from TLS connection:" + std::string(gnutls_strerror(res
)));
858 else if (res
== GNUTLS_E_AGAIN
) {
859 int result
= waitForData(d_socket
, readTimeout
);
861 throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result
));
865 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res
));
870 time_t now
= time(nullptr);
871 unsigned int elapsed
= now
- start
;
872 if (now
< start
|| elapsed
>= remainingTime
) {
873 throw runtime_error("Timeout while reading data");
876 remainingTime
-= elapsed
;
879 while (got
< bufferSize
);
884 size_t write(const void* buffer
, size_t bufferSize
, unsigned int writeTimeout
) override
889 ssize_t res
= gnutls_record_send(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), bufferSize
- got
);
891 throw std::runtime_error("Error writing to TLS connection");
894 got
+= static_cast<size_t>(res
);
897 if (gnutls_error_is_fatal(res
)) {
898 throw std::runtime_error("Error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
900 else if (res
== GNUTLS_E_AGAIN
) {
901 int result
= waitForRWData(d_socket
, false, writeTimeout
, 0);
903 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result
));
907 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
911 while (got
< bufferSize
);
916 void close() override
919 gnutls_bye(d_conn
.get(), GNUTLS_SHUT_WR
);
924 std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)> d_conn
;
925 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
;
928 class GnuTLSIOCtx
: public TLSCtx
931 GnuTLSIOCtx(const TLSFrontend
& fe
): d_creds(std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)>(nullptr, gnutls_certificate_free_credentials
)), d_enableTickets(fe
.d_enableTickets
)
934 d_ticketsKeyRotationDelay
= fe
.d_ticketsKeyRotationDelay
;
936 gnutls_certificate_credentials_t creds
;
937 rc
= gnutls_certificate_allocate_credentials(&creds
);
938 if (rc
!= GNUTLS_E_SUCCESS
) {
939 throw std::runtime_error("Error allocating credentials for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
942 d_creds
= std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)>(creds
, gnutls_certificate_free_credentials
);
945 for (const auto& pair
: fe
.d_certKeyPairs
) {
946 rc
= gnutls_certificate_set_x509_key_file(d_creds
.get(), pair
.first
.c_str(), pair
.second
.c_str(), GNUTLS_X509_FMT_PEM
);
947 if (rc
!= GNUTLS_E_SUCCESS
) {
948 throw std::runtime_error("Error loading certificate ('" + pair
.first
+ "') and key ('" + pair
.second
+ "') for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
952 #if GNUTLS_VERSION_NUMBER >= 0x030600
953 rc
= gnutls_certificate_set_known_dh_params(d_creds
.get(), GNUTLS_SEC_PARAM_HIGH
);
954 if (rc
!= GNUTLS_E_SUCCESS
) {
955 throw std::runtime_error("Error setting DH params for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
959 rc
= gnutls_priority_init(&d_priorityCache
, fe
.d_ciphers
.empty() ? "NORMAL" : fe
.d_ciphers
.c_str(), nullptr);
960 if (rc
!= GNUTLS_E_SUCCESS
) {
961 warnlog("Error setting up TLS cipher preferences to %s (%s), skipping.", fe
.d_ciphers
.c_str(), gnutls_strerror(rc
));
964 pthread_rwlock_init(&d_lock
, nullptr);
967 if (fe
.d_ticketKeyFile
.empty()) {
968 handleTicketsKeyRotation(time(nullptr));
971 loadTicketsKeys(fe
.d_ticketKeyFile
);
974 catch(const std::runtime_error
& e
) {
975 pthread_rwlock_destroy(&d_lock
);
976 throw std::runtime_error("Error generating tickets key for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + e
.what());
980 virtual ~GnuTLSIOCtx() override
982 pthread_rwlock_destroy(&d_lock
);
986 if (d_priorityCache
) {
987 gnutls_priority_deinit(d_priorityCache
);
991 std::unique_ptr
<TLSConnection
> getConnection(int socket
, unsigned int timeout
, time_t now
) override
993 handleTicketsKeyRotation(now
);
995 std::shared_ptr
<GnuTLSTicketsKey
> ticketsKey
;
997 ReadLock
rl(&d_lock
);
998 ticketsKey
= d_ticketsKey
;
1001 return std::unique_ptr
<GnuTLSConnection
>(new GnuTLSConnection(socket
, timeout
, d_creds
.get(), d_priorityCache
, ticketsKey
, d_enableTickets
));
1004 void rotateTicketsKey(time_t now
) override
1006 if (!d_enableTickets
) {
1010 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>();
1013 WriteLock
wl(&d_lock
);
1014 d_ticketsKey
= newKey
;
1017 if (d_ticketsKeyRotationDelay
> 0) {
1018 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
1022 void loadTicketsKeys(const std::string
& file
) override
1024 if (!d_enableTickets
) {
1028 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>(file
);
1030 WriteLock
wl(&d_lock
);
1031 d_ticketsKey
= newKey
;
1034 if (d_ticketsKeyRotationDelay
> 0) {
1035 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
1039 size_t getTicketsKeysCount() override
1041 ReadLock
rl(&d_lock
);
1042 return d_ticketsKey
!= nullptr ? 1 : 0;
1046 std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)> d_creds
;
1047 gnutls_priority_t d_priorityCache
{nullptr};
1048 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
{nullptr};
1049 pthread_rwlock_t d_lock
;
1050 bool d_enableTickets
{true};
1053 #endif /* HAVE_GNUTLS */
1055 #endif /* HAVE_DNS_OVER_TLS */
1057 bool TLSFrontend::setupTLS()
1059 #ifdef HAVE_DNS_OVER_TLS
1060 /* get the "best" available provider */
1061 if (!d_provider
.empty()) {
1063 if (d_provider
== "gnutls") {
1064 d_ctx
= std::make_shared
<GnuTLSIOCtx
>(*this);
1067 #endif /* HAVE_GNUTLS */
1069 if (d_provider
== "openssl") {
1070 d_ctx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
1073 #endif /* HAVE_LIBSSL */
1076 d_ctx
= std::make_shared
<GnuTLSIOCtx
>(*this);
1077 #else /* HAVE_GNUTLS */
1079 d_ctx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
1080 #endif /* HAVE_LIBSSL */
1081 #endif /* HAVE_GNUTLS */
1083 #endif /* HAVE_DNS_OVER_TLS */