6 #include "tcpiohandler.hh"
10 #endif /* HAVE_LIBSODIUM */
12 #ifdef HAVE_DNS_OVER_TLS
14 #include <openssl/conf.h>
15 #include <openssl/err.h>
16 #include <openssl/rand.h>
17 #include <openssl/ssl.h>
// Per-frontend OpenSSL state: the server SSL_CTX together with the tickets
// key ring, the OCSP responses and the optional key log file attached to it.
// Held through std::shared_ptr by OpenSSLTLSIOCtx and by every
// OpenSSLTLSConnection created from it.
21 class OpenSSLFrontendContext
// Builds the server context for the frontend bound to addr from tlsConfig.
// Throws std::runtime_error when the SSL_CTX cannot be created.
24 OpenSSLFrontendContext(const ComboAddress
& addr
, const TLSConfig
& tlsConfig
): d_ticketKeys(tlsConfig
.d_numberOfTicketsKeys
)
// Balanced by the unregisterOpenSSLUser() call further down.
26 registerOpenSSLUser();
// d_ocspResponses is handed to the initialiser, presumably to be filled with
// the OCSP responses loaded for the certificates — TODO confirm.
28 d_tlsCtx
= libssl_init_server_context(tlsConfig
, d_ocspResponses
);
// Error path (its guarding condition is not visible in this view): dump the
// OpenSSL error queue to stderr, then fail construction.
30 ERR_print_errors_fp(stderr
);
31 throw std::runtime_error("Error creating TLS context on " + addr
.toStringWithPort());
// Teardown counterpart of registerOpenSSLUser(); presumably in the
// destructor, whose signature line is not visible here.
39 unregisterOpenSSLUser();
// Tickets key ring, sized from tlsConfig.d_numberOfTicketsKeys.
42 OpenSSLTLSTicketKeysRing d_ticketKeys
;
// OCSP responses for this context; presumably keyed by certificate index —
// TODO confirm.
43 std::map
<int, std::string
> d_ocspResponses
;
// Owned server context, released with SSL_CTX_free.
44 std::unique_ptr
<SSL_CTX
, void(*)(SSL_CTX
*)> d_tlsCtx
{nullptr, SSL_CTX_free
};
// Optional TLS key log file, closed with fclose.
45 std::unique_ptr
<FILE, int(*)(FILE*)> d_keyLogFile
{nullptr, fclose
};
// A single server-side TLS connection handled with OpenSSL, built on an SSL
// object created from the frontend's SSL_CTX.
48 class OpenSSLTLSConnection
: public TLSConnection
// Creates the SSL object for the given socket. feContext is kept alive for
// the lifetime of the connection; timeout is stored in d_timeout and used by
// the blocking doHandshake(). Throws std::runtime_error on any setup failure.
51 OpenSSLTLSConnection(int socket
, unsigned int timeout
, std::shared_ptr
<OpenSSLFrontendContext
> feContext
): d_feContext(feContext
), d_conn(std::unique_ptr
<SSL
, void(*)(SSL
*)>(SSL_new(d_feContext
->d_tlsCtx
.get()), SSL_free
)), d_timeout(timeout
)
// The first connection ever created allocates the process-wide ex-data index
// used to map an SSL* back to its OpenSSLTLSConnection (read in ticketKeyCb).
// NOTE(review): a concurrent constructor that finds the flag already set does
// not wait for the index to be assigned and may still see -1; and if the
// allocation fails the flag stays set, so later connections never retry.
// std::call_once would address both.
55 if (!s_initTLSConnIndex
.test_and_set()) {
56 /* not initialized yet */
57 s_tlsConnIndex
= SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
58 if (s_tlsConnIndex
== -1) {
59 throw std::runtime_error("Error getting an index for TLS connection data");
// Failure to create the SSL object (the guarding test is not visible here).
64 vinfolog("Error creating TLS object");
66 ERR_print_errors_fp(stderr
);
68 throw std::runtime_error("Error creating TLS object");
// Bind the SSL object to d_socket (presumably set from `socket` by code not
// visible here — TODO confirm).
71 if (!SSL_set_fd(d_conn
.get(), d_socket
)) {
72 throw std::runtime_error("Error assigning socket");
// Back-pointer so OpenSSL callbacks can recover this connection from the SSL*.
75 SSL_set_ex_data(d_conn
.get(), s_tlsConnIndex
, this);
// Translates the return value of an SSL_* I/O call into the I/O state the
// caller must wait for: NeedRead for SSL_ERROR_WANT_READ, NeedWrite for
// SSL_ERROR_WANT_WRITE. Any other error throws std::runtime_error.
78 IOState
convertIORequestToIOState(int res
) const
80 int error
= SSL_get_error(d_conn
.get(), res
);
81 if (error
== SSL_ERROR_WANT_READ
) {
82 return IOState::NeedRead
;
84 else if (error
== SSL_ERROR_WANT_WRITE
) {
85 return IOState::NeedWrite
;
87 else if (error
== SSL_ERROR_SYSCALL
) {
// NOTE(review): errno can be 0 here (e.g. the peer closed the socket without
// a close_notify), which typically yields a misleading message.
88 throw std::runtime_error("Error while processing TLS connection: " + std::string(strerror(errno
)));
// Any other SSL error: report the numeric SSL_get_error() code.
91 throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error
));
// Blocking helper used after an SSL_* call that did not complete: waits until
// d_socket is readable or writable, as OpenSSL requested, for at most
// timeout. Throws std::runtime_error on timeout or wait error (the tests on
// the wait result are not visible in this view).
95 void handleIORequest(int res
, unsigned int timeout
)
97 auto state
= convertIORequestToIOState(res
);
98 if (state
== IOState::NeedRead
) {
99 res
= waitForData(d_socket
, timeout
);
101 throw std::runtime_error("Timeout while reading from TLS connection");
104 throw std::runtime_error("Error waiting to read from TLS connection");
107 else if (state
== IOState::NeedWrite
) {
// NOTE(review): presumably `false` selects waiting for writability — TODO confirm.
108 res
= waitForRWData(d_socket
, false, timeout
, 0);
110 throw std::runtime_error("Timeout while writing to TLS connection");
113 throw std::runtime_error("Error waiting to write to TLS connection");
// Non-blocking handshake step: a single SSL_accept() attempt. Returns Done
// once the handshake completed, otherwise the state to wait for as computed
// by convertIORequestToIOState(); throws std::runtime_error on failure.
// (The tests on res selecting between these paths are not visible here.)
118 IOState
tryHandshake() override
120 int res
= SSL_accept(d_conn
.get());
122 return IOState::Done
;
125 return convertIORequestToIOState(res
);
128 throw std::runtime_error("Error accepting TLS connection");
// Blocking handshake: calls SSL_accept() and, whenever OpenSSL needs more
// I/O, waits via handleIORequest() for at most d_timeout. Throws
// std::runtime_error if the handshake fails.
131 void doHandshake() override
135 res
= SSL_accept(d_conn
.get());
137 handleIORequest(res
, d_timeout
);
143 throw std::runtime_error("Error accepting TLS connection");
// Non-blocking write of buffer[pos, toWrite). pos is advanced by the bytes
// written; returns Done once pos reaches toWrite, or the state to wait for
// (from convertIORequestToIOState()) when SSL_write() cannot make progress.
147 IOState
tryWrite(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toWrite
) override
150 int res
= SSL_write(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), static_cast<int>(toWrite
- pos
));
// Incomplete write (the test on res is not visible here).
152 return convertIORequestToIOState(res
);
155 pos
+= static_cast<size_t>(res
);
158 while (pos
< toWrite
);
159 return IOState::Done
;
// Non-blocking read into buffer[pos, toRead). pos is advanced by the bytes
// read; returns Done once pos reaches toRead, or the state to wait for
// (from convertIORequestToIOState()) when SSL_read() cannot make progress.
162 IOState
tryRead(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toRead
) override
165 int res
= SSL_read(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), static_cast<int>(toRead
- pos
));
// Incomplete read (the test on res is not visible here).
167 return convertIORequestToIOState(res
);
170 pos
+= static_cast<size_t>(res
);
173 while (pos
< toRead
);
174 return IOState::Done
;
// Blocking read filling buffer with bufferSize bytes. Every wait for data is
// bounded by readTimeout (through handleIORequest()). The elapsed wall-clock
// time is also checked against totalTimeout (the condition enabling that
// check is not visible here), throwing std::runtime_error when it runs out.
177 size_t read(void* buffer
, size_t bufferSize
, unsigned int readTimeout
, unsigned int totalTimeout
) override
181 unsigned int remainingTime
= totalTimeout
;
183 start
= time(nullptr);
187 int res
= SSL_read(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
189 handleIORequest(res
, readTimeout
);
192 got
+= static_cast<size_t>(res
);
// Charge the elapsed time against the total budget; a clock going backwards
// (now < start) is treated as a timeout too.
196 time_t now
= time(nullptr);
197 unsigned int elapsed
= now
- start
;
198 if (now
< start
|| elapsed
>= remainingTime
) {
199 throw runtime_error("Timeout while reading data");
202 remainingTime
-= elapsed
;
205 while (got
< bufferSize
);
// Blocking write of the bufferSize bytes at buffer. Each time OpenSSL needs
// more I/O, waits through handleIORequest() for at most writeTimeout.
210 size_t write(const void* buffer
, size_t bufferSize
, unsigned int writeTimeout
) override
214 int res
= SSL_write(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), static_cast<int>(bufferSize
- got
));
216 handleIORequest(res
, writeTimeout
);
219 got
+= static_cast<size_t>(res
);
222 while (got
< bufferSize
);
// Sends the TLS close_notify alert via SSL_shutdown() (any guard around the
// call is not visible in this view).
227 void close() override
230 SSL_shutdown(d_conn
.get());
// Returns the host name the client sent in the SNI extension, or an empty
// string when none is available.
234 std::string
getServerNameIndication() const override
237 const char* value
= SSL_get_servername(d_conn
.get(), TLSEXT_NAMETYPE_host_name
);
239 return std::string(value
);
242 return std::string();
// Maps SSL_version() of the session to LibsslTLSVersion. TLS 1.3 is only
// recognised when the OpenSSL headers define TLS1_3_VERSION; any other value
// maps to Unknown.
245 LibsslTLSVersion
getTLSVersion() const override
247 auto proto
= SSL_version(d_conn
.get());
250 return LibsslTLSVersion::TLS10
;
252 return LibsslTLSVersion::TLS11
;
254 return LibsslTLSVersion::TLS12
;
255 #ifdef TLS1_3_VERSION
257 return LibsslTLSVersion::TLS13
;
258 #endif /* TLS1_3_VERSION */
260 return LibsslTLSVersion::Unknown
;
// True when this session was resumed from a previous one (SSL_session_reused()).
264 bool hasSessionBeenResumed() const override
267 return SSL_session_reused(d_conn
.get()) != 0;
// Process-wide SSL ex-data index holding the OpenSSLTLSConnection
// back-pointer; -1 until the first connection allocates it.
272 static int s_tlsConnIndex
;
// Set by the first connection, which then allocates s_tlsConnIndex.
275 static std::atomic_flag s_initTLSConnIndex
;
// Frontend context shared with OpenSSLTLSIOCtx; keeps the SSL_CTX alive.
277 std::shared_ptr
<OpenSSLFrontendContext
> d_feContext
;
// Owned SSL object, released with SSL_free.
278 std::unique_ptr
<SSL
, void(*)(SSL
*)> d_conn
;
// Timeout given at construction, used by doHandshake().
279 unsigned int d_timeout
;
// Out-of-class definitions of the static members: the flag starts cleared
// and the index starts at -1 (not allocated yet).
282 std::atomic_flag
OpenSSLTLSConnection::s_initTLSConnIndex
= ATOMIC_FLAG_INIT
;
283 int OpenSSLTLSConnection::s_tlsConnIndex
= -1;
// OpenSSL implementation of TLSCtx for one frontend: owns the shared
// OpenSSLFrontendContext and hands out OpenSSLTLSConnection objects.
285 class OpenSSLTLSIOCtx
: public TLSCtx
// Sets up the OpenSSL context for frontend fe, in this order:
// tickets key callback, OCSP stapling, error counters, optional key logging,
// then the initial tickets keys.
288 OpenSSLTLSIOCtx(TLSFrontend
& fe
)
290 d_feContext
= std::make_shared
<OpenSSLFrontendContext
>(fe
.d_addr
, fe
.d_tlsConfig
);
292 d_ticketsKeyRotationDelay
= fe
.d_tlsConfig
.d_ticketsKeyRotationDelay
;
294 if (fe
.d_tlsConfig
.d_enableTickets
&& fe
.d_tlsConfig
.d_numberOfTicketsKeys
> 0) {
295 /* use our own ticket keys handler so we can rotate them */
296 SSL_CTX_set_tlsext_ticket_key_cb(d_feContext
->d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ticketKeyCb
);
// ticketKeyCb reads the frontend context back through this pointer.
297 libssl_set_ticket_key_callback_data(d_feContext
->d_tlsCtx
.get(), d_feContext
.get());
// Install OCSP stapling only when some responses were loaded.
300 if (!d_feContext
->d_ocspResponses
.empty()) {
301 SSL_CTX_set_tlsext_status_cb(d_feContext
->d_tlsCtx
.get(), &OpenSSLTLSIOCtx::ocspStaplingCb
);
302 SSL_CTX_set_tlsext_status_arg(d_feContext
->d_tlsCtx
.get(), &d_feContext
->d_ocspResponses
);
305 libssl_set_error_counters_callback(d_feContext
->d_tlsCtx
, &fe
.d_tlsCounters
);
307 if (!fe
.d_tlsConfig
.d_keyLogFile
.empty()) {
308 d_feContext
->d_keyLogFile
= libssl_set_key_log_file(d_feContext
->d_tlsCtx
, fe
.d_tlsConfig
.d_keyLogFile
);
// No key file configured: presumably the rotation logic generates the first
// keys. Otherwise load them from the file. The qualified call explicitly
// selects this class's implementation instead of dispatching virtually.
312 if (fe
.d_tlsConfig
.d_ticketKeyFile
.empty()) {
313 handleTicketsKeyRotation(time(nullptr));
316 OpenSSLTLSIOCtx::loadTicketsKeys(fe
.d_tlsConfig
.d_ticketKeyFile
);
// Handler of a try block whose opening line is not visible here; its body
// is not visible either.
319 catch (const std::exception
& e
) {
// Destructor (body not visible in this view). d_feContext is a shared_ptr,
// so connections still holding it keep the frontend state alive.
324 ~OpenSSLTLSIOCtx() override
// Session ticket key callback installed on the SSL_CTX. It recovers the
// OpenSSLFrontendContext registered with libssl_set_ticket_key_callback_data()
// and delegates to libssl_ticket_key_callback() using that context's key ring.
328 static int ticketKeyCb(SSL
*s
, unsigned char keyName
[TLS_TICKETS_KEY_NAME_SIZE
], unsigned char *iv
, EVP_CIPHER_CTX
*ectx
, HMAC_CTX
*hctx
, int enc
)
330 OpenSSLFrontendContext
* ctx
= reinterpret_cast<OpenSSLFrontendContext
*>(libssl_get_ticket_key_callback_data(s
));
// No context attached (the early-exit body is not visible here).
331 if (ctx
== nullptr) {
335 int ret
= libssl_ticket_key_callback(s
, ctx
->d_ticketKeys
, keyName
, iv
, ectx
, hctx
, enc
);
// Under OpenSSL's ticket key callback contract, when decrypting, 0 means the
// ticket cannot be used and 2 means it is valid but should be renewed.
// The owning connection is flagged accordingly (presumably 0 -> unknown key,
// 2 -> resumed from a non-current key; the branch lines are not visible).
337 if (ret
== 0 || ret
== 2) {
338 OpenSSLTLSConnection
* conn
= reinterpret_cast<OpenSSLTLSConnection
*>(SSL_get_ex_data(s
, OpenSSLTLSConnection::s_tlsConnIndex
));
341 conn
->setUnknownTicketKey();
344 conn
->setResumedFromInactiveTicketKey();
// OCSP stapling callback. arg is the frontend's OCSP response map registered
// with SSL_CTX_set_tlsext_status_arg(); choosing the response is delegated to
// libssl_ocsp_stapling_callback(). Returns SSL_TLSEXT_ERR_NOACK when called
// without an SSL object or map.
353 static int ocspStaplingCb(SSL
* ssl
, void* arg
)
355 if (ssl
== nullptr || arg
== nullptr) {
356 return SSL_TLSEXT_ERR_NOACK
;
358 const auto ocspMap
= reinterpret_cast<std::map
<int, std::string
>*>(arg
);
359 return libssl_ocsp_stapling_callback(ssl
, *ocspMap
);
// Gives the tickets key rotation logic a chance to run at `now`, then wraps
// socket in a new OpenSSLTLSConnection sharing this frontend's context.
362 std::unique_ptr
<TLSConnection
> getConnection(int socket
, unsigned int timeout
, time_t now
) override
364 handleTicketsKeyRotation(now
);
366 return std::unique_ptr
<OpenSSLTLSConnection
>(new OpenSSLTLSConnection(socket
, timeout
, d_feContext
));
// Rotates the keys of the frontend's ring (OpenSSLTLSTicketKeysRing) and, when
// periodic rotation is enabled (d_ticketsKeyRotationDelay > 0), schedules the
// next rotation relative to `now`.
369 void rotateTicketsKey(time_t now
) override
371 d_feContext
->d_ticketKeys
.rotateTicketsKey(now
);
373 if (d_ticketsKeyRotationDelay
> 0) {
374 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
// Loads the tickets keys from keyFile into the frontend's ring, then schedules
// the next automatic rotation from the current time when rotation is enabled.
378 void loadTicketsKeys(const std::string
& keyFile
) override final
380 d_feContext
->d_ticketKeys
.loadTicketsKeys(keyFile
);
382 if (d_ticketsKeyRotationDelay
> 0) {
383 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
// Number of tickets keys currently held in the frontend's ring.
387 size_t getTicketsKeysCount() override
389 return d_feContext
->d_ticketKeys
.getKeysCount();
// Frontend state, shared with every connection created by this context.
393 std::shared_ptr
<OpenSSLFrontendContext
> d_feContext
;
396 #endif /* HAVE_LIBSSL */
399 #include <gnutls/gnutls.h>
400 #include <gnutls/x509.h>
// Locks [data, data + size) in memory with sodium_mlock() when built with
// libsodium; any other preprocessor branch is not visible in this view.
402 static void safe_memory_lock(void* data
, size_t size
)
404 #ifdef HAVE_LIBSODIUM
405 sodium_mlock(data
, size
);
// Wipes [data, data + size) so key material does not linger in memory. It uses
// the first available primitive:
//   - sodium_munlock(), which also zeroes the memory before unlocking it;
//   - explicit_bzero();
//   - explicit_memset();
//   - gnutls_memset();
//   - as a last resort, a memset() loop that re-reads the buffer through a
//     volatile pointer and index so the compiler cannot elide the wipe.
409 static void safe_memory_release(void* data
, size_t size
)
411 #ifdef HAVE_LIBSODIUM
412 sodium_munlock(data
, size
);
413 #elif defined(HAVE_EXPLICIT_BZERO)
414 explicit_bzero(data
, size
);
415 #elif defined(HAVE_EXPLICIT_MEMSET)
416 explicit_memset(data
, 0, size
);
417 #elif defined(HAVE_GNUTLS_MEMSET)
418 gnutls_memset(data
, 0, size
);
420 /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
421 volatile unsigned int volatile_zero_idx
= 0;
422 volatile unsigned char *p
= reinterpret_cast<volatile unsigned char *>(data
);
428 memset(data
, 0, size
);
429 } while (p
[volatile_zero_idx
] != 0);
// Holder for one GnuTLS session tickets key. The key bytes are locked via
// safe_memory_lock() while held and wiped/freed on release.
// NOTE(review): d_key.data is owned as a raw pointer, so copying an instance
// must be prevented (not visible in this view) to avoid a double gnutls_free.
433 class GnuTLSTicketsKey
// Default constructor (its signature line is not visible here): generates a
// fresh random key.
438 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
439 throw std::runtime_error("Error generating tickets key for TLS context");
442 safe_memory_lock(d_key
.data
, d_key
.size
);
// Loads the key from keyFile. A file that cannot supply the key is reported
// as std::runtime_error("Invalid GnuTLS tickets key file ...").
445 GnuTLSTicketsKey(const std::string
& keyFile
)
447 /* to be sure we are loading the correct amount of data, which
448 may change between versions, let's generate a correct key first */
449 if (gnutls_session_ticket_key_generate(&d_key
) != GNUTLS_E_SUCCESS
) {
450 throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
453 safe_memory_lock(d_key
.data
, d_key
.size
);
// Overwrite the generated key with d_key.size bytes read from the file.
456 ifstream
file(keyFile
);
457 file
.read(reinterpret_cast<char*>(d_key
.data
), d_key
.size
);
// Reached when the read did not succeed (the test is not visible here).
461 throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile
);
// On any failure, wipe and free the partially loaded key (presumably the
// exception is then rethrown — TODO confirm).
466 catch (const std::exception
& e
) {
467 safe_memory_release(d_key
.data
, d_key
.size
);
468 gnutls_free(d_key
.data
);
469 d_key
.data
= nullptr;
// Release path (presumably the destructor, whose signature is not visible):
// wipe and free the key material if any is held.
476 if (d_key
.data
!= nullptr && d_key
.size
> 0) {
477 safe_memory_release(d_key
.data
, d_key
.size
);
479 gnutls_free(d_key
.data
);
480 d_key
.data
= nullptr;
// Borrowed reference to the key, valid for the lifetime of this object.
482 const gnutls_datum_t
& getKey() const
// The key material; {nullptr, 0} until generated.
488 gnutls_datum_t d_key
{nullptr, 0};
// A single server-side TLS connection handled with GnuTLS.
491 class GnuTLSConnection
: public TLSConnection
// Creates a non-blocking server session on socket using the shared
// credentials and priority cache. The tickets key is installed when
// enableTickets is set and a key is available. timeout (in seconds,
// multiplied by 1000 below) bounds the handshake and record operations inside
// GnuTLS. Throws std::runtime_error on any setup failure.
495 GnuTLSConnection(int socket
, unsigned int timeout
, const gnutls_certificate_credentials_t creds
, const gnutls_priority_t priorityCache
, std::shared_ptr
<GnuTLSTicketsKey
>& ticketsKey
, bool enableTickets
): d_conn(std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(nullptr, gnutls_deinit
)), d_ticketsKey(ticketsKey
)
497 unsigned int sslOptions
= GNUTLS_SERVER
| GNUTLS_NONBLOCK
;
// Also request GNUTLS_NO_SIGNAL when the GnuTLS headers provide it.
498 #ifdef GNUTLS_NO_SIGNAL
499 sslOptions
|= GNUTLS_NO_SIGNAL
;
// Initialise a raw session handle, then hand ownership to d_conn (released
// with gnutls_deinit).
504 gnutls_session_t conn
;
505 if (gnutls_init(&conn
, sslOptions
) != GNUTLS_E_SUCCESS
) {
506 throw std::runtime_error("Error creating TLS connection");
509 d_conn
= std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)>(conn
, gnutls_deinit
);
512 if (gnutls_credentials_set(d_conn
.get(), GNUTLS_CRD_CERTIFICATE
, creds
) != GNUTLS_E_SUCCESS
) {
513 throw std::runtime_error("Error setting certificate and key to TLS connection");
516 if (gnutls_priority_set(d_conn
.get(), priorityCache
) != GNUTLS_E_SUCCESS
) {
517 throw std::runtime_error("Error setting ciphers to TLS connection");
520 if (enableTickets
&& d_ticketsKey
) {
521 const gnutls_datum_t
& key
= d_ticketsKey
->getKey();
522 if (gnutls_session_ticket_enable_server(d_conn
.get(), &key
) != GNUTLS_E_SUCCESS
) {
523 throw std::runtime_error("Error setting the tickets key to TLS connection");
// Attach the session to d_socket (presumably set from `socket` by code not
// visible here — TODO confirm).
527 gnutls_transport_set_int(d_conn
.get(), d_socket
);
529 /* timeouts are in milliseconds */
530 gnutls_handshake_set_timeout(d_conn
.get(), timeout
* 1000);
531 gnutls_record_set_timeout(d_conn
.get(), timeout
* 1000);
// Handshake loop: calls gnutls_handshake() again while it is only
// interrupted. Throws std::runtime_error on fatal errors and on warning alerts.
// NOTE(review): the session is created with GNUTLS_NONBLOCK, so a
// GNUTLS_E_AGAIN result leaves the loop without an exception while the
// handshake is still incomplete; confirm callers account for that.
534 void doHandshake() override
538 ret
= gnutls_handshake(d_conn
.get());
539 if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
540 throw std::runtime_error("Error accepting a new connection");
543 while (ret
< 0 && ret
== GNUTLS_E_INTERRUPTED
);
// Non-blocking handshake step. Calls gnutls_handshake() again while it is
// only interrupted. Returns Done once the handshake completed, or the I/O
// state to wait for when GnuTLS would block. Throws std::runtime_error on
// fatal errors and on warning alerts.
546 IOState
tryHandshake() override
551 ret
= gnutls_handshake(d_conn
.get());
552 if (ret
== GNUTLS_E_SUCCESS
) {
553 return IOState::Done
;
555 else if (ret
== GNUTLS_E_AGAIN
) {
// The handshake can block in either direction (e.g. while sending the
// server's flight into a full socket buffer), so ask GnuTLS which one it
// was: 0 means it was reading, 1 that it was writing. Always answering
// NeedRead would make the caller wait for input that never comes.
556 return gnutls_record_get_direction(d_conn
.get()) == 0 ? IOState::NeedRead : IOState::NeedWrite
;
558 else if (gnutls_error_is_fatal(ret
) || ret
== GNUTLS_E_WARNING_ALERT_RECEIVED
) {
559 throw std::runtime_error("Error accepting a new connection");
561 } while (ret
== GNUTLS_E_INTERRUPTED
);
563 throw std::runtime_error("Error accepting a new connection");
// Non-blocking write of buffer[pos, toWrite). pos is advanced by the bytes
// sent; returns Done once pos reaches toWrite, or NeedWrite on
// GNUTLS_E_AGAIN. Fatal GnuTLS errors throw std::runtime_error (as does one
// more condition whose test is not visible here); other non-fatal errors are
// logged and the write is retried.
566 IOState
tryWrite(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toWrite
) override
569 ssize_t res
= gnutls_record_send(d_conn
.get(), reinterpret_cast<const char *>(&buffer
.at(pos
)), toWrite
- pos
);
571 throw std::runtime_error("Error writing to TLS connection");
574 pos
+= static_cast<size_t>(res
);
577 if (gnutls_error_is_fatal(res
)) {
578 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
580 else if (res
== GNUTLS_E_AGAIN
) {
581 return IOState::NeedWrite
;
583 warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
586 while (pos
< toWrite
);
587 return IOState::Done
;
// Non-blocking read into buffer[pos, toRead). pos is advanced by the bytes
// received; returns Done once pos reaches toRead, or NeedRead on
// GNUTLS_E_AGAIN. Fatal GnuTLS errors throw std::runtime_error (as does one
// more condition whose test is not visible here); other non-fatal errors are
// logged and the read is retried.
590 IOState
tryRead(std::vector
<uint8_t>& buffer
, size_t& pos
, size_t toRead
) override
593 ssize_t res
= gnutls_record_recv(d_conn
.get(), reinterpret_cast<char *>(&buffer
.at(pos
)), toRead
- pos
);
595 throw std::runtime_error("Error reading from TLS connection");
598 pos
+= static_cast<size_t>(res
);
601 if (gnutls_error_is_fatal(res
)) {
602 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res
)));
604 else if (res
== GNUTLS_E_AGAIN
) {
605 return IOState::NeedRead
;
// This is the read path: the message previously said "writing".
607 warnlog("Warning, non-fatal error while reading from TLS connection: %s", gnutls_strerror(res
));
610 while (pos
< toRead
);
611 return IOState::Done
;
// Blocking read filling buffer with bufferSize bytes. On GNUTLS_E_AGAIN it
// waits for d_socket to become readable for at most readTimeout. The elapsed
// wall-clock time is also checked against totalTimeout (the enabling
// condition is not visible here). Fatal errors, wait errors and timeouts
// throw std::runtime_error; other non-fatal errors are logged and retried.
614 size_t read(void* buffer
, size_t bufferSize
, unsigned int readTimeout
, unsigned int totalTimeout
) override
618 unsigned int remainingTime
= totalTimeout
;
620 start
= time(nullptr);
624 ssize_t res
= gnutls_record_recv(d_conn
.get(), (reinterpret_cast<char *>(buffer
) + got
), bufferSize
- got
);
626 throw std::runtime_error("Error reading from TLS connection");
629 got
+= static_cast<size_t>(res
);
632 if (gnutls_error_is_fatal(res
)) {
633 throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res
)));
635 else if (res
== GNUTLS_E_AGAIN
) {
636 int result
= waitForData(d_socket
, readTimeout
);
638 throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result
));
642 vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res
));
// Charge the elapsed time against the total budget; a clock going backwards
// (now < start) is treated as a timeout too.
647 time_t now
= time(nullptr);
648 unsigned int elapsed
= now
- start
;
649 if (now
< start
|| elapsed
>= remainingTime
) {
650 throw runtime_error("Timeout while reading data");
653 remainingTime
-= elapsed
;
656 while (got
< bufferSize
);
// Blocking write of the bufferSize bytes at buffer. On GNUTLS_E_AGAIN it
// waits for d_socket for at most writeTimeout. Fatal errors and wait errors
// throw std::runtime_error; other non-fatal errors are logged and retried.
661 size_t write(const void* buffer
, size_t bufferSize
, unsigned int writeTimeout
) override
666 ssize_t res
= gnutls_record_send(d_conn
.get(), (reinterpret_cast<const char *>(buffer
) + got
), bufferSize
- got
);
668 throw std::runtime_error("Error writing to TLS connection");
671 got
+= static_cast<size_t>(res
);
674 if (gnutls_error_is_fatal(res
)) {
675 throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res
)));
677 else if (res
== GNUTLS_E_AGAIN
) {
// NOTE(review): presumably `false` selects waiting for writability — TODO confirm.
678 int result
= waitForRWData(d_socket
, false, writeTimeout
, 0);
680 throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result
));
684 vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res
));
688 while (got
< bufferSize
);
// Returns the host name the client sent in the SNI extension, or an empty
// string when none is available. The name is read into a 256-byte buffer.
693 std::string
getServerNameIndication() const override
697 size_t name_len
= 256;
699 sni
.resize(name_len
);
// Write through a genuinely mutable pointer into the string's storage:
// modifying the array returned by c_str() is undefined behaviour.
701 int res
= gnutls_server_name_get(d_conn
.get(), &sni
.at(0), &name_len
, &type
, 0);
702 if (res
== GNUTLS_E_SUCCESS
) {
// Shrink to the name length reported by GnuTLS.
703 sni
.resize(name_len
);
707 return std::string();
// Maps gnutls_protocol_get_version() of the session to LibsslTLSVersion.
// TLS 1.3 is only recognised when built against GnuTLS >= 3.6.3; any other
// value maps to Unknown.
710 LibsslTLSVersion
getTLSVersion() const override
712 auto proto
= gnutls_protocol_get_version(d_conn
.get());
715 return LibsslTLSVersion::TLS10
;
717 return LibsslTLSVersion::TLS11
;
719 return LibsslTLSVersion::TLS12
;
720 #if GNUTLS_VERSION_NUMBER >= 0x030603
722 return LibsslTLSVersion::TLS13
;
723 #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
725 return LibsslTLSVersion::Unknown
;
// True when this session was resumed from a previous one (gnutls_session_is_resumed()).
729 bool hasSessionBeenResumed() const override
732 return gnutls_session_is_resumed(d_conn
.get()) != 0;
// Sends the TLS close_notify alert via gnutls_bye(GNUTLS_SHUT_WR), which does
// not wait for the peer's reply (any guard around the call is not visible here).
737 void close() override
740 gnutls_bye(d_conn
.get(), GNUTLS_SHUT_WR
);
// Owned GnuTLS session, released with gnutls_deinit.
745 std::unique_ptr
<gnutls_session_int
, void(*)(gnutls_session_t
)> d_conn
;
// Tickets key used by this session, shared with the GnuTLSIOCtx that created it.
746 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
;
// GnuTLS implementation of TLSCtx for one frontend: owns the certificate
// credentials, the priority cache and the current tickets key.
749 class GnuTLSIOCtx
: public TLSCtx
// Sets up the GnuTLS state for frontend fe:
//   - loads every certificate/key pair and OCSP response;
//   - sets DH parameters (GnuTLS >= 3.6.0);
//   - builds the priority cache from d_ciphers ("NORMAL" when empty);
//   - sets up the initial tickets key.
// Throws std::runtime_error naming the frontend address on any failure.
752 GnuTLSIOCtx(TLSFrontend
& fe
): d_creds(std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)>(nullptr, gnutls_certificate_free_credentials
)), d_enableTickets(fe
.d_tlsConfig
.d_enableTickets
)
755 d_ticketsKeyRotationDelay
= fe
.d_tlsConfig
.d_ticketsKeyRotationDelay
;
// Allocate into a raw handle, then transfer ownership to d_creds.
757 gnutls_certificate_credentials_t creds
;
758 rc
= gnutls_certificate_allocate_credentials(&creds
);
759 if (rc
!= GNUTLS_E_SUCCESS
) {
760 throw std::runtime_error("Error allocating credentials for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
763 d_creds
= std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)>(creds
, gnutls_certificate_free_credentials
);
766 for (const auto& pair
: fe
.d_tlsConfig
.d_certKeyPairs
) {
767 rc
= gnutls_certificate_set_x509_key_file(d_creds
.get(), pair
.first
.c_str(), pair
.second
.c_str(), GNUTLS_X509_FMT_PEM
);
768 if (rc
!= GNUTLS_E_SUCCESS
) {
769 throw std::runtime_error("Error loading certificate ('" + pair
.first
+ "') and key ('" + pair
.second
+ "') for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
// OCSP files are matched to certificates by position (count); presumably
// count is declared and incremented per file on lines not visible here.
// NOTE(review): in the error path, d_certKeyPairs.at(count) throws
// std::out_of_range instead of the intended message when there are more OCSP
// files than certificate pairs.
774 for (const auto& file
: fe
.d_tlsConfig
.d_ocspFiles
) {
775 rc
= gnutls_certificate_set_ocsp_status_request_file(d_creds
.get(), file
.c_str(), count
);
776 if (rc
!= GNUTLS_E_SUCCESS
) {
777 throw std::runtime_error("Error loading OCSP response from file '" + file
+ "' for certificate ('" + fe
.d_tlsConfig
.d_certKeyPairs
.at(count
).first
+ "') and key ('" + fe
.d_tlsConfig
.d_certKeyPairs
.at(count
).second
+ "') for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
782 #if GNUTLS_VERSION_NUMBER >= 0x030600
783 rc
= gnutls_certificate_set_known_dh_params(d_creds
.get(), GNUTLS_SEC_PARAM_HIGH
);
784 if (rc
!= GNUTLS_E_SUCCESS
) {
785 throw std::runtime_error("Error setting DH params for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + gnutls_strerror(rc
));
// NOTE(review): d_priorityCache is a raw handle released only by the
// destructor, so it leaks if this constructor throws after it is initialised
// (e.g. from the tickets key setup below).
789 rc
= gnutls_priority_init(&d_priorityCache
, fe
.d_tlsConfig
.d_ciphers
.empty() ? "NORMAL" : fe
.d_tlsConfig
.d_ciphers
.c_str(), nullptr);
790 if (rc
!= GNUTLS_E_SUCCESS
) {
791 throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe
.d_tlsConfig
.d_ciphers
+ "' (" + gnutls_strerror(rc
) + ") on " + fe
.d_addr
.toStringWithPort());
// Initial tickets key: presumably generated through the rotation logic when
// no key file is configured, otherwise loaded from the file. The qualified
// call explicitly selects this class's implementation.
795 if (fe
.d_tlsConfig
.d_ticketKeyFile
.empty()) {
796 handleTicketsKeyRotation(time(nullptr));
799 GnuTLSIOCtx::loadTicketsKeys(fe
.d_tlsConfig
.d_ticketKeyFile
);
// Re-throw tickets key failures with the frontend address for context.
802 catch(const std::runtime_error
& e
) {
803 throw std::runtime_error("Error generating tickets key for TLS context on " + fe
.d_addr
.toStringWithPort() + ": " + e
.what());
// Releases the priority cache if it was created; the credentials are freed by
// d_creds' unique_ptr deleter.
807 virtual ~GnuTLSIOCtx() override
811 if (d_priorityCache
) {
812 gnutls_priority_deinit(d_priorityCache
);
// Gives the tickets key rotation logic a chance to run at `now`. It then takes
// a snapshot of the current tickets key under the read lock and creates a
// GnuTLSConnection for socket that shares this context's credentials and
// priority cache.
816 std::unique_ptr
<TLSConnection
> getConnection(int socket
, unsigned int timeout
, time_t now
) override
818 handleTicketsKeyRotation(now
);
820 std::shared_ptr
<GnuTLSTicketsKey
> ticketsKey
;
822 ReadLock
rl(&d_lock
);
823 ticketsKey
= d_ticketsKey
;
826 return std::unique_ptr
<GnuTLSConnection
>(new GnuTLSConnection(socket
, timeout
, d_creds
.get(), d_priorityCache
, ticketsKey
, d_enableTickets
));
// Replaces the current tickets key with a freshly generated one. The key is
// built outside the lock and published under the write lock; when periodic
// rotation is enabled, the next rotation is scheduled relative to `now`.
// When tickets are disabled the function bails out early (the body of that
// test is not visible here).
829 void rotateTicketsKey(time_t now
) override
831 if (!d_enableTickets
) {
835 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>();
838 WriteLock
wl(&d_lock
);
839 d_ticketsKey
= newKey
;
842 if (d_ticketsKeyRotationDelay
> 0) {
843 d_ticketsKeyNextRotation
= now
+ d_ticketsKeyRotationDelay
;
// Replaces the current tickets key with one loaded from file. The key is
// built outside the lock and published under the write lock; when periodic
// rotation is enabled, the next rotation is scheduled from the current time.
// When tickets are disabled the function bails out early (the body of that
// test is not visible here).
847 void loadTicketsKeys(const std::string
& file
) override final
849 if (!d_enableTickets
) {
853 auto newKey
= std::make_shared
<GnuTLSTicketsKey
>(file
);
855 WriteLock
wl(&d_lock
);
856 d_ticketsKey
= newKey
;
859 if (d_ticketsKeyRotationDelay
> 0) {
860 d_ticketsKeyNextRotation
= time(nullptr) + d_ticketsKeyRotationDelay
;
// 1 when a tickets key is installed, 0 otherwise (read under the lock).
864 size_t getTicketsKeysCount() override
866 ReadLock
rl(&d_lock
);
867 return d_ticketsKey
!= nullptr ? 1 : 0;
// Certificate credentials shared by all sessions, freed with
// gnutls_certificate_free_credentials.
871 std::unique_ptr
<gnutls_certificate_credentials_st
, void(*)(gnutls_certificate_credentials_t
)> d_creds
;
// Cipher priority cache; released in the destructor.
872 gnutls_priority_t d_priorityCache
{nullptr};
// Current tickets key, guarded by d_lock.
873 std::shared_ptr
<GnuTLSTicketsKey
> d_ticketsKey
{nullptr};
874 ReadWriteLock d_lock
;
875 bool d_enableTickets
{true};
878 #endif /* HAVE_GNUTLS */
880 #endif /* HAVE_DNS_OVER_TLS */
// Creates d_ctx for this frontend. An explicitly requested provider
// (d_provider "gnutls" or "openssl") is used when compiled in. Otherwise
// OpenSSL is preferred when available, falling back to GnuTLS. Nothing is
// created when built without DNS-over-TLS support. (The return statements
// are not visible in this view.)
882 bool TLSFrontend::setupTLS()
884 #ifdef HAVE_DNS_OVER_TLS
885 /* get the "best" available provider */
886 if (!d_provider
.empty()) {
888 if (d_provider
== "gnutls") {
889 d_ctx
= std::make_shared
<GnuTLSIOCtx
>(*this);
892 #endif /* HAVE_GNUTLS */
894 if (d_provider
== "openssl") {
895 d_ctx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
898 #endif /* HAVE_LIBSSL */
901 d_ctx
= std::make_shared
<OpenSSLTLSIOCtx
>(*this);
902 #else /* HAVE_LIBSSL */
904 d_ctx
= std::make_shared
<GnuTLSIOCtx
>(*this);
905 #endif /* HAVE_GNUTLS */
906 #endif /* HAVE_LIBSSL */
908 #endif /* HAVE_DNS_OVER_TLS */