4 #include "circular_buffer.hh"
8 #include "tcpiohandler.hh"
12 #endif /* HAVE_LIBSODIUM */
14 #ifdef HAVE_DNS_OVER_TLS
16 #include <openssl/conf.h>
17 #include <openssl/err.h>
18 #include <openssl/rand.h>
19 #include <openssl/ssl.h>
23 /* From rfc5077 Section 4. Recommended Ticket Construction */
24 #define TLS_TICKETS_KEY_NAME_SIZE (16)
27 #define TLS_TICKETS_CIPHER_KEY_SIZE (32)
28 #define TLS_TICKETS_CIPHER_ALGO (EVP_aes_256_cbc)
31 #define TLS_TICKETS_MAC_KEY_SIZE (32)
32 #define TLS_TICKETS_MAC_ALGO (EVP_sha256)
34 static int s_ticketsKeyIndex
{-1};
36 class OpenSSLTLSTicketKey
41 if (RAND_bytes(d_name
, sizeof(d_name
)) != 1) {
42 throw std::runtime_error("Error while generating the name of the OpenSSL TLS ticket key");
45 if (RAND_bytes(d_cipherKey
, sizeof(d_cipherKey
)) != 1) {
46 throw std::runtime_error("Error while generating the cipher key of the OpenSSL TLS ticket key");
49 if (RAND_bytes(d_hmacKey
, sizeof(d_hmacKey
)) != 1) {
50 throw std::runtime_error("Error while generating the HMAC key of the OpenSSL TLS ticket key");
53 sodium_mlock(d_name
, sizeof(d_name
));
54 sodium_mlock(d_cipherKey
, sizeof(d_cipherKey
));
55 sodium_mlock(d_hmacKey
, sizeof(d_hmacKey
));
56 #endif /* HAVE_LIBSODIUM */
59 OpenSSLTLSTicketKey(ifstream
& file
)
61 file
.read(reinterpret_cast<char*>(d_name
), sizeof(d_name
));
62 file
.read(reinterpret_cast<char*>(d_cipherKey
), sizeof(d_cipherKey
));
63 file
.read(reinterpret_cast<char*>(d_hmacKey
), sizeof(d_hmacKey
));
66 throw std::runtime_error("Unable to load a ticket key from the OpenSSL tickets key file");
69 sodium_mlock(d_name
, sizeof(d_name
));
70 sodium_mlock(d_cipherKey
, sizeof(d_cipherKey
));
71 sodium_mlock(d_hmacKey
, sizeof(d_hmacKey
));
72 #endif /* HAVE_LIBSODIUM */
75 ~OpenSSLTLSTicketKey()
78 sodium_munlock(d_name
, sizeof(d_name
));
79 sodium_munlock(d_cipherKey
, sizeof(d_cipherKey
));
80 sodium_munlock(d_hmacKey
, sizeof(d_hmacKey
));
82 OPENSSL_cleanse(d_name
, sizeof(d_name
));
83 OPENSSL_cleanse(d_cipherKey
, sizeof(d_cipherKey
));
84 OPENSSL_cleanse(d_hmacKey
, sizeof(d_hmacKey
));
85 #endif /* HAVE_LIBSODIUM */
88 bool nameMatches(const unsigned char name
[TLS_TICKETS_KEY_NAME_SIZE
]) const
90 return (memcmp(d_name
, name
, sizeof(d_name
)) == 0);
93 int encrypt(unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char *iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
) const
95 memcpy(keyName
, d_name
, sizeof(d_name
));
97 if (RAND_bytes(iv
, EVP_MAX_IV_LENGTH
) != 1) {
101 if (EVP_EncryptInit_ex(ectx
, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey
, iv
) != 1) {
105 if (HMAC_Init_ex(hctx
, d_hmacKey
, sizeof(d_hmacKey
), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
112 bool decrypt(const unsigned char* iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
) const
114 if (HMAC_Init_ex(hctx
, d_hmacKey
, sizeof(d_hmacKey
), TLS_TICKETS_MAC_ALGO(), nullptr) != 1) {
118 if (EVP_DecryptInit_ex(ectx
, TLS_TICKETS_CIPHER_ALGO(), nullptr, d_cipherKey
, iv
) != 1) {
126 unsigned char d_name
[TLS_TICKETS_KEY_NAME_SIZE
];
127 unsigned char d_cipherKey
[TLS_TICKETS_CIPHER_KEY_SIZE
];
128 unsigned char d_hmacKey
[TLS_TICKETS_MAC_KEY_SIZE
];
131 class OpenSSLTLSTicketKeysRing
134 OpenSSLTLSTicketKeysRing(size_t capacity
)
136 pthread_rwlock_init(&d_lock
, nullptr);
137 d_ticketKeys
.set_capacity(capacity
);
140 ~OpenSSLTLSTicketKeysRing()
142 pthread_rwlock_destroy(&d_lock
);
145 void addKey(std::shared_ptr
<OpenSSLTLSTicketKey
> newKey
)
147 WriteLock
wl(&d_lock
);
148 d_ticketKeys
.push_back(newKey
);
151 std::shared_ptr
<OpenSSLTLSTicketKey
> getEncryptionKey()
153 ReadLock
rl(&d_lock
);
154 return d_ticketKeys
.front();
157 std::shared_ptr
<OpenSSLTLSTicketKey
> getDecryptionKey(unsigned char name
[TLS_TICKETS_KEY_NAME_SIZE
], bool& activeKey
)
159 ReadLock
rl(&d_lock
);
160 for (auto& key
: d_ticketKeys
) {
161 if (key
->nameMatches(name
)) {
162 activeKey
= (key
== d_ticketKeys
.front());
169 size_t getKeysCount()
171 ReadLock
rl(&d_lock
);
172 return d_ticketKeys
.size();
176 boost::circular_buffer
<std::shared_ptr
<OpenSSLTLSTicketKey
> > d_ticketKeys
;
177 pthread_rwlock_t d_lock
;
180 class OpenSSLTLSConnection
: public TLSConnection
183 OpenSSLTLSConnection(int socket
, unsigned int timeout
, SSL_CTX
* tlsCtx
): d_conn(std::unique_ptr
<SSL
, void(*)(SSL
*)>(SSL_new(tlsCtx
), SSL_free
)), d_timeout(timeout
)
188 vinfolog("Error creating TLS object");
190 ERR_print_errors_fp(stderr
);
192 throw std::runtime_error("Error creating TLS object");
195 if (!SSL_set_fd(d_conn
.get(), d_socket
)) {
196 throw std::runtime_error("Error assigning socket");
200 IOState
convertIORequestToIOState(int res
) const
202 int error
= SSL_get_error(d_conn
.get(), res
);
203 if (error
== SSL_ERROR_WANT_READ
) {
204 return IOState::NeedRead
;
206 else if (error
== SSL_ERROR_WANT_WRITE
) {
207 return IOState::NeedWrite
;
209 else if (error
== SSL_ERROR_SYSCALL
) {
210 throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno
)));
213 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error
));
217 void handleIORequest(int res
, unsigned int timeout
)
219 auto state
= convertIORequestToIOState(res
);
220 if (state
== IOState::NeedRead
) {
221 res
= waitForData(d_socket
, timeout
);
223 throw std::runtime_error("Timeout while reading from TLS connection");
226 throw std::runtime_error("Error waiting to read from TLS connection");
229 else if (state
== IOState::NeedWrite
) {
230 res
= waitForRWData(d_socket
, false, timeout
, 0);
232 throw std::runtime_error("Timeout while writing to TLS connection");
235 throw std::runtime_error("Error waiting to write to TLS connection");
240 IOState
tryHandshake() override
242 int res
= SSL_accept(d_conn
.get());
244 return IOState::Done
;
247 return convertIORequestToIOState(res
);
250 throw std::runtime_error("Error accepting TLS connection");
253 void doHandshake() override
257 res
= SSL_accept(d_conn
.get());
259 handleIORequest(res
, d_timeout
);
265 throw std::runtime_error("Error accepting TLS connection");
269 IOState
tryWrite(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toWrite
) override
272 int res
= SSL_write(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), static_cast<int>(toWrite
- pos
));
274 return convertIORequestToIOState(res
);
277 pos
+= static_cast<size_t>(res
);
280 while (pos
< toWrite
);
281 return IOState::Done
;
284 IOState
tryRead(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toRead
) override
287 int res
= SSL_read(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), static_cast<int>(toRead
- pos
));
289 return convertIORequestToIOState(res
);
292 pos
+= static_cast<size_t>(res
);
295 while (pos
< toRead
);
296 return IOState::Done
;
299 size_t read(void* buffer
, size_t bufferSize
, unsigned int readTimeout
, unsigned int totalTimeout
) override
303 unsigned int remainingTime
= totalTimeout
;
305 start
= time(nullptr);
309 int res
= SSL_read(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
311 handleIORequest(res
, readTimeout
);
314 got
+= static_cast<size_t>(res
);
318 time_t now
= time(nullptr);
319 unsigned int elapsed
= now
- start
;
320 if (now
< start
|| elapsed
>= remainingTime
) {
321 throw runtime_error("Timeout while reading data");
324 remainingTime
-= elapsed
;
327 while (got
< bufferSize
);
332 size_t write(const void* buffer
, size_t bufferSize
, unsigned int writeTimeout
) override
336 int res
= SSL_write(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
338 handleIORequest(res
, writeTimeout
);
341 got
+= static_cast<size_t>(res
);
344 while (got
< bufferSize
);
349 void close() override
352 SSL_shutdown(d_conn
.get());
356 std::string
getServerNameIndication() override
359 const char* value
= SSL_get_servername(d_conn
.get(), TLSEXT_NAMETYPE_host_name
);
361 return std::string(value
);
364 return std::string();
368 std::unique_ptr
<SSL
, void(*)(SSL
*)> d_conn
;
369 unsigned int d_timeout
;
372 class OpenSSLTLSIOCtx
: public TLSCtx
375 OpenSSLTLSIOCtx(const TLSFrontend
& fe
): d_ticketKeys(fe
.d_numberOfTicketsKeys
), d_tlsCtx(std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)>(nullptr, SSL_CTX_free
))
377 d_ticketsKeyRotationDelay
= fe
.d_ticketsKeyRotationDelay
;
382 SSL_OP_NO_COMPRESSION
|
383 SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
|
384 SSL_OP_SINGLE_DH_USE
|
385 SSL_OP_SINGLE_ECDH_USE
|
386 SSL_OP_CIPHER_SERVER_PREFERENCE
;
388 if (!fe
.d_enableTickets
) {
389 sslOptions
|= SSL_OP_NO_TICKET
;
392 if (s_users
.fetch_add(1) == 0) {
393 registerOpenSSLUser();
395 s_ticketsKeyIndex
= SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
397 if (s_ticketsKeyIndex
== -1) {
398 throw std::runtime_error("Error getting an index for tickets key");
402 d_tlsCtx
= std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)>(SSL_CTX_new(SSLv23_server_method()), SSL_CTX_free
);
404 ERR_print_errors_fp(stderr
);
405 throw std::runtime_error("Error creating TLS context on " + fe
.d_addr
.toStringWithPort());
408 /* use our own ticket keys handler so we can rotate them */
409 SSL_CTX_set_tlsext_ticket_key_cb(d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ticketKeyCb
);
410 SSL_CTX_set_ex_data(d_tlsCtx
.get(), s_ticketsKeyIndex
, this);
411 SSL_CTX_set_options(d_tlsCtx
.get(), sslOptions
);
412 #if defined(SSL_CTX_set_ecdh_auto)
413 SSL_CTX_set_ecdh_auto(d_tlsCtx
.get(), 1);
415 if (fe
.d_maxStoredSessions
== 0) {
416 /* disable stored sessions entirely */
417 SSL_CTX_set_session_cache_mode(d_tlsCtx
.get(), SSL_SESS_CACHE_OFF
);
420 /* use the internal built-in cache to store sessions */
421 SSL_CTX_set_session_cache_mode(d_tlsCtx
.get(), SSL_SESS_CACHE_SERVER
);
422 SSL_CTX_sess_set_cache_size(d_tlsCtx
.get(), fe
.d_maxStoredSessions
);
425 std::vector
<int> keyTypes
;
426 for (const auto& pair
: fe
.d_certKeyPairs
) {
427 if (SSL_CTX_use_certificate_chain_file(d_tlsCtx
.get(), pair
.first
.c_str()) != 1) {
428 ERR_print_errors_fp(stderr
);
429 throw std::runtime_error("Error loading certificate from " + pair
.first
+ " for the TLS context on " + fe
.d_addr
.toStringWithPort());
431 if (SSL_CTX_use_PrivateKey_file(d_tlsCtx
.get(), pair
.second
.c_str(), SSL_FILETYPE_PEM
) != 1) {
432 ERR_print_errors_fp(stderr
);
433 throw std::runtime_error("Error loading key from " + pair
.second
+ " for the TLS context on " + fe
.d_addr
.toStringWithPort());
435 if (SSL_CTX_check_private_key(d_tlsCtx
.get()) != 1) {
436 ERR_print_errors_fp(stderr
);
437 throw std::runtime_error("Key from '" + pair
.second
+ "' does not match the certificate from '" + pair
.first
+ "' for the TLS context on " + fe
.d_addr
.toStringWithPort());
440 /* store the type of the new key, we might need it later to select the right OCSP stapling response */
441 keyTypes
.push_back(libssl_get_last_key_type(d_tlsCtx
));
444 if (!fe
.d_ocspFiles
.empty()) {
446 d_ocspResponses
= libssl_load_ocsp_responses(fe
.d_ocspFiles
, keyTypes
);
448 catch(const std::exception
& e
) {
449 throw std::runtime_error("Error loading responses for the TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + e
.what());
452 SSL_CTX_set_tlsext_status_cb(d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ocspStaplingCb
);
453 SSL_CTX_set_tlsext_status_arg(d_tlsCtx
.get(), &d_ocspResponses
);
456 if (!fe
.d_ciphers
.empty()) {
457 if (SSL_CTX_set_cipher_list(d_tlsCtx
.get(), fe
.d_ciphers
.c_str()) != 1) {
458 ERR_print_errors_fp(stderr
);
459 throw std::runtime_error("Error setting the cipher list to '" + fe
.d_ciphers
+ "' for the TLS context on " + fe
.d_addr
.toStringWithPort());
463 #ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
464 if (!fe
.d_ciphers13
.empty()) {
465 if (SSL_CTX_set_ciphersuites(d_tlsCtx
.get(), fe
.d_ciphers13
.c_str()) != 1) {
466 ERR_print_errors_fp(stderr
);
467 throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + fe
.d_ciphers13
+ "' for the TLS context on " + fe
.d_addr
.toStringWithPort());
470 #endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
473 if (fe
.d_ticketKeyFile
.empty()) {
474 handleTicketsKeyRotation(time(nullptr));
477 loadTicketsKeys(fe
.d_ticketKeyFile
);
480 catch (const std::exception
& e
) {
485 virtual ~OpenSSLTLSIOCtx() override
489 if (s_users
.fetch_sub(1) == 1) {
490 unregisterOpenSSLUser();
494 static int ticketKeyCb(SSL
*s
, unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char *iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
, int enc
)
496 SSL_CTX
* sslCtx
= SSL_get_SSL_CTX(s
);
497 if (sslCtx
== nullptr) {
501 OpenSSLTLSIOCtx
* ctx
= reinterpret_cast<OpenSSLTLSIOCtx
*>(SSL_CTX_get_ex_data(sslCtx
, s_ticketsKeyIndex
));
502 if (ctx
== nullptr) {
507 const auto key
= ctx
->d_ticketKeys
.getEncryptionKey();
508 if (key
== nullptr) {
512 return key
->encrypt(keyName
, iv
, ectx
, hctx
);
515 bool activeEncryptionKey
= false;
517 const auto key
= ctx
->d_ticketKeys
.getDecryptionKey(keyName
, activeEncryptionKey
);
518 if (key
== nullptr) {
519 /* we don't know this key, just create a new ticket */
523 if (key
->decrypt(iv
, ectx
, hctx
) == false) {
527 if (!activeEncryptionKey
) {
528 /* this key is not active, please encrypt the ticket content with the currently active one */
535 static int ocspStaplingCb(SSL
* ssl
, void* arg
)
537 if (ssl
== nullptr || arg
== nullptr) {
538 return SSL_TLSEXT_ERR_NOACK
;
540 const auto ocspMap
= reinterpret_cast<std::map
<int, std::string
>*>(arg
);
541 return libssl_ocsp_stapling_callback(ssl
, *ocspMap
);
544 std::unique_ptr
<TLSConnection
> getConnection(int socket
, unsigned int timeout
, time_t now
) override
546 handleTicketsKeyRotation(now
);
548 return std::unique_ptr
<OpenSSLTLSConnection
>(new OpenSSLTLSConnection(socket
, timeout
, d_tlsCtx
.get()));
551 void rotateTicketsKey(time_t now
) override
553 auto newKey
= std::make_shared
<OpenSSLTLSTicketKey
>();
554 d_ticketKeys
.addKey(newKey
);
556 if (d_ticketsKeyRotationDelay
> 0) {
557 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
561 void loadTicketsKeys(const std::string
& keyFile
) override
563 bool keyLoaded
= false;
564 ifstream
file(keyFile
);
567 auto newKey
= std::make_shared
<OpenSSLTLSTicketKey
>(file
);
568 d_ticketKeys
.addKey(newKey
);
571 while (!file
.fail());
573 catch (const std::exception
& e
) {
574 /* if we haven't been able to load at least one key, fail */
580 if (d_ticketsKeyRotationDelay
> 0) {
581 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
587 size_t getTicketsKeysCount() override
589 return d_ticketKeys
.getKeysCount();
593 OpenSSLTLSTicketKeysRing d_ticketKeys
;
594 std::map
<int, std::string
> d_ocspResponses
;
595 std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)> d_tlsCtx
;
596 static std::atomic
<uint64_t> s_users
;
599 std::atomic
<uint64_t> OpenSSLTLSIOCtx::s_users(0);
601 #endif /* HAVE_LIBSSL */
604 #include <gnutls/gnutls.h>
605 #include <gnutls/x509.h>
607 void safe_memory_lock(void* data
, size_t size
)
609 #ifdef HAVE_LIBSODIUM
610 sodium_mlock(data
, size
);
614 void safe_memory_release(void* data
, size_t size
)
616 #ifdef HAVE_LIBSODIUM
617 sodium_munlock(data
, size
);
618 #elif defined(HAVE_EXPLICIT_BZERO)
619 explicit_bzero(data
, size
);
620 #elif defined(HAVE_EXPLICIT_MEMSET)
621 explicit_memset(data
, 0, size
);
622 #elif defined(HAVE_GNUTLS_MEMSET)
623 gnutls_memset(data
, 0, size
);
625 /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
626 volatile unsigned int volatile_zero_idx
= 0;
627 volatile unsigned char *p
= reinterpret_cast<volatile unsigned char *>(data
);
633 memset(data
, 0, size
);
634 } while (p
[volatile_zero_idx
] != 0);
638 class GnuTLSTicketsKey
643 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
644 throw std::runtime_error("Error generating tickets key for TLS context");
647 safe_memory_lock(d_key
.data
, d_key
.size
);
650 GnuTLSTicketsKey(const std::string
& keyFile
)
652 /* to be sure we are loading the correct amount of data, which
653 may change between versions, let's generate a correct key first */
654 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
655 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
658 safe_memory_lock(d_key
.data
, d_key
.size
);
661 ifstream
file(keyFile
);
662 file
.read(reinterpret_cast<char*>(d_key
.data
), d_key
.size
);
666 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile
);
671 catch (const std::exception
& e
) {
672 safe_memory_release(d_key
.data
, d_key
.size
);
673 gnutls_free(d_key
.data
);
674 d_key
.data
= nullptr;
681 if (d_key
.data
!= nullptr && d_key
.size
> 0) {
682 safe_memory_release(d_key
.data
, d_key
.size
);
684 gnutls_free(d_key
.data
);
685 d_key
.data
= nullptr;
687 const gnutls_datum_t
& getKey() const
693 gnutls_datum_t d_key
{nullptr, 0};
696 class GnuTLSConnection
: public TLSConnection
700 GnuTLSConnection(int socket
, unsigned int timeout
, const gnutls_certificate_credentials_t creds
, const gnutls_priority_t priorityCache
, std::shared_ptr
<GnuTLSTicketsKey
>& ticketsKey
, bool enableTickets
): d_conn(std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(nullptr, gnutls_deinit
)), d_ticketsKey(ticketsKey
)
702 unsigned int sslOptions
= GNUTLS_SERVER
| GNUTLS_NONBLOCK
;
703 #ifdef GNUTLS_NO_SIGNAL
704 sslOptions
|= GNUTLS_NO_SIGNAL
;
709 gnutls_session_t conn
;
710 if (gnutls_init(&conn
, sslOptions
) != GNUTLS_E_SUCCESS
) {
711 throw std::runtime_error("Error creating TLS connection");
714 d_conn
= std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(conn
, gnutls_deinit
);
717 if (gnutls_credentials_set(d_conn
.get(), GNUTLS_CRD_CERTIFICATE
, creds
) != GNUTLS_E_SUCCESS
) {
718 throw std::runtime_error("Error setting certificate and key to TLS connection");
721 if (gnutls_priority_set(d_conn
.get(), priorityCache
) != GNUTLS_E_SUCCESS
) {
722 throw std::runtime_error("Error setting ciphers to TLS connection");
725 if (enableTickets
&& d_ticketsKey
) {
726 const gnutls_datum_t
& key
= d_ticketsKey
->getKey();
727 if (gnutls_session_ticket_enable_server(d_conn
.get(), &key
) != GNUTLS_E_SUCCESS
) {
728 throw std::runtime_error("Error setting the tickets key to TLS connection");
732 gnutls_transport_set_int(d_conn
.get(), d_socket
);
734 /* timeouts are in milliseconds */
735 gnutls_handshake_set_timeout(d_conn
.get(), timeout
* 1000);
736 gnutls_record_set_timeout(d_conn
.get(), timeout
* 1000);
739 void doHandshake() override
743 ret
= gnutls_handshake(d_conn
.get());
744 if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
745 throw std::runtime_error("Error accepting a new connection");
748 while (ret
< 0 && ret
== GNUTLS_E_INTERRUPTED
);
751 IOState
tryHandshake() override
756 ret
= gnutls_handshake(d_conn
.get());
757 if (ret
== GNUTLS_E_SUCCESS
) {
758 return IOState::Done
;
760 else if (ret
== GNUTLS_E_AGAIN
) {
761 return IOState::NeedRead
;
763 else if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
764 throw std::runtime_error("Error accepting a new connection");
766 } while (ret
== GNUTLS_E_INTERRUPTED
);
768 throw std::runtime_error("Error accepting a new connection");
771 IOState
tryWrite(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toWrite
) override
774 ssize_t res
= gnutls_record_send(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), toWrite
- pos
);
776 throw std::runtime_error("Error writing to TLS connection");
779 pos
+= static_cast<size_t>(res
);
782 if (gnutls_error_is_fatal(res
)) {
783 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
785 else if (res
== GNUTLS_E_AGAIN
) {
786 return IOState::NeedWrite
;
788 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
791 while (pos
< toWrite
);
792 return IOState::Done
;
795 IOState
tryRead(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toRead
) override
798 ssize_t res
= gnutls_record_recv(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), toRead
- pos
);
800 throw std::runtime_error("Error reading from TLS connection");
803 pos
+= static_cast<size_t>(res
);
806 if (gnutls_error_is_fatal(res
)) {
807 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res
)));
809 else if (res
== GNUTLS_E_AGAIN
) {
810 return IOState::NeedRead
;
812 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
815 while (pos
< toRead
);
816 return IOState::Done
;
819 size_t read(void* buffer
, size_t bufferSize
, unsigned int readTimeout
, unsigned int totalTimeout
) override
823 unsigned int remainingTime
= totalTimeout
;
825 start
= time(nullptr);
829 ssize_t res
= gnutls_record_recv(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), bufferSize
- got
);
831 throw std::runtime_error("Error reading from TLS connection");
834 got
+= static_cast<size_t>(res
);
837 if (gnutls_error_is_fatal(res
)) {
838 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res
)));
840 else if (res
== GNUTLS_E_AGAIN
) {
841 int result
= waitForData(d_socket
, readTimeout
);
843 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result
));
847 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res
));
852 time_t now
= time(nullptr);
853 unsigned int elapsed
= now
- start
;
854 if (now
< start
|| elapsed
>= remainingTime
) {
855 throw runtime_error("Timeout while reading data");
858 remainingTime
-= elapsed
;
861 while (got
< bufferSize
);
866 size_t write(const void* buffer
, size_t bufferSize
, unsigned int writeTimeout
) override
871 ssize_t res
= gnutls_record_send(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), bufferSize
- got
);
873 throw std::runtime_error("Error writing to TLS connection");
876 got
+= static_cast<size_t>(res
);
879 if (gnutls_error_is_fatal(res
)) {
880 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
882 else if (res
== GNUTLS_E_AGAIN
) {
883 int result
= waitForRWData(d_socket
, false, writeTimeout
, 0);
885 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result
));
889 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
893 while (got
< bufferSize
);
898 std::string
getServerNameIndication() override
902 size_t name_len
= 256;
904 sni
.resize(name_len
);
906 int res
= gnutls_server_name_get(d_conn
.get(), const_cast<char*>(sni
.c_str()), &name_len
, &type
, 0);
907 if (res
== GNUTLS_E_SUCCESS
) {
908 sni
.resize(name_len
);
912 return std::string();
915 void close() override
918 gnutls_bye(d_conn
.get(), GNUTLS_SHUT_WR
);
923 std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)> d_conn
;
924 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
;
927 class GnuTLSIOCtx
: public TLSCtx
930 GnuTLSIOCtx(const TLSFrontend
& fe
): d_creds(std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)>(nullptr, gnutls_certificate_free_credentials
)), d_enableTickets(fe
.d_enableTickets
)
933 d_ticketsKeyRotationDelay
= fe
.d_ticketsKeyRotationDelay
;
935 gnutls_certificate_credentials_t creds
;
936 rc
= gnutls_certificate_allocate_credentials(&creds
);
937 if (rc
!= GNUTLS_E_SUCCESS
) {
938 throw std::runtime_error("Error allocating credentials for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
941 d_creds
= std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)>(creds
, gnutls_certificate_free_credentials
);
944 for (const auto& pair
: fe
.d_certKeyPairs
) {
945 rc
= gnutls_certificate_set_x509_key_file(d_creds
.get(), pair
.first
.c_str(), pair
.second
.c_str(), GNUTLS_X509_FMT_PEM
);
946 if (rc
!= GNUTLS_E_SUCCESS
) {
947 throw std::runtime_error("Error loading certificate ('" + pair
.first
+ "') and key ('" + pair
.second
+ "') for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
952 for (const auto& file
: fe
.d_ocspFiles
) {
953 rc
= gnutls_certificate_set_ocsp_status_request_file(d_creds
.get(), file
.c_str(), count
);
954 if (rc
!= GNUTLS_E_SUCCESS
) {
955 throw std::runtime_error("Error loading OCSP response from file '" + file
+ "' for certificate ('" + fe
.d_certKeyPairs
.at(count
).first
+ "') and key ('" + fe
.d_certKeyPairs
.at(count
).second
+ "') for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
960 #if GNUTLS_VERSION_NUMBER >= 0x030600
961 rc
= gnutls_certificate_set_known_dh_params(d_creds
.get(), GNUTLS_SEC_PARAM_HIGH
);
962 if (rc
!= GNUTLS_E_SUCCESS
) {
963 throw std::runtime_error("Error setting DH params for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
967 rc
= gnutls_priority_init(&d_priorityCache
, fe
.d_ciphers
.empty() ? "NORMAL" : fe
.d_ciphers
.c_str(), nullptr);
968 if (rc
!= GNUTLS_E_SUCCESS
) {
969 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe
.d_ciphers
+ "' (" + gnutls_strerror(rc
) + ") on " + fe
.d_addr
.toStringWithPort());
972 pthread_rwlock_init(&d_lock
, nullptr);
975 if (fe
.d_ticketKeyFile
.empty()) {
976 handleTicketsKeyRotation(time(nullptr));
979 loadTicketsKeys(fe
.d_ticketKeyFile
);
982 catch(const std::runtime_error
& e
) {
983 pthread_rwlock_destroy(&d_lock
);
984 throw std::runtime_error("Error generating tickets key for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + e
.what());
988 virtual ~GnuTLSIOCtx() override
990 pthread_rwlock_destroy(&d_lock
);
994 if (d_priorityCache
) {
995 gnutls_priority_deinit(d_priorityCache
);
999 std::unique_ptr
<TLSConnection
> getConnection(int socket
, unsigned int timeout
, time_t now
) override
1001 handleTicketsKeyRotation(now
);
1003 std::shared_ptr
<GnuTLSTicketsKey
> ticketsKey
;
1005 ReadLock
rl(&d_lock
);
1006 ticketsKey
= d_ticketsKey
;
1009 return std::unique_ptr
<GnuTLSConnection
>(new GnuTLSConnection(socket
, timeout
, d_creds
.get(), d_priorityCache
, ticketsKey
, d_enableTickets
));
1012 void rotateTicketsKey(time_t now
) override
1014 if (!d_enableTickets
) {
1018 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>();
1021 WriteLock
wl(&d_lock
);
1022 d_ticketsKey
= newKey
;
1025 if (d_ticketsKeyRotationDelay
> 0) {
1026 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
1030 void loadTicketsKeys(const std::string
& file
) override
1032 if (!d_enableTickets
) {
1036 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>(file
);
1038 WriteLock
wl(&d_lock
);
1039 d_ticketsKey
= newKey
;
1042 if (d_ticketsKeyRotationDelay
> 0) {
1043 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
1047 size_t getTicketsKeysCount() override
1049 ReadLock
rl(&d_lock
);
1050 return d_ticketsKey
!= nullptr ? 1 : 0;
1054 std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)> d_creds
;
1055 gnutls_priority_t d_priorityCache
{nullptr};
1056 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
{nullptr};
1057 pthread_rwlock_t d_lock
;
1058 bool d_enableTickets
{true};
1061 #endif /* HAVE_GNUTLS */
1063 #endif /* HAVE_DNS_OVER_TLS */
1065 bool TLSFrontend::setupTLS()
1067 #ifdef HAVE_DNS_OVER_TLS
1068 /* get the "best" available provider */
1069 if (!d_provider
.empty()) {
1071 if (d_provider
== "gnutls") {
1072 d_ctx
= std::make_shared
<GnuTLSIOCtx
>(*this);
1075 #endif /* HAVE_GNUTLS */
1077 if (d_provider
== "openssl") {
1078 d_ctx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
1081 #endif /* HAVE_LIBSSL */
1084 d_ctx
= std::make_shared
<GnuTLSIOCtx
>(*this);
1085 #else /* HAVE_GNUTLS */
1087 d_ctx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
1088 #endif /* HAVE_LIBSSL */
1089 #endif /* HAVE_GNUTLS */
1091 #endif /* HAVE_DNS_OVER_TLS */